From 51d3557f8f8fe23c144b98da77a0a7613e021407 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Apr 2024 13:00:16 -0700 Subject: [PATCH 001/360] pulled in sync generator from previous branch --- sync_surface_generator.py | 523 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 523 insertions(+) create mode 100644 sync_surface_generator.py diff --git a/sync_surface_generator.py b/sync_surface_generator.py new file mode 100644 index 000000000..5eea38425 --- /dev/null +++ b/sync_surface_generator.py @@ -0,0 +1,523 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import inspect +import ast +import textwrap +import time +import queue +import os +import threading +import concurrent.futures + +from black import format_str, FileMode +import autoflake +""" +This module allows us to generate a synchronous API surface from our asyncio surface. +""" + +# This map defines replacements for asyncio API calls +asynciomap = { + "sleep": ({"time": time}, "time.sleep"), + "Queue": ({"queue": queue}, "queue.Queue"), + "Condition": ({"threading": threading}, "threading.Condition"), + "Future": ({"concurrent.futures": concurrent.futures}, "concurrent.futures.Future"), +} + +# This map defines find/replace pairs for the generated code and docstrings +# replace async calls with corresponding sync ones +name_map = { + "__anext__": "__next__", + "__aiter__": "__iter__", + "__aenter__": "__enter__", + "__aexit__": "__exit__", + "aclose": "close", + "AsyncIterable": "Iterable", + "AsyncIterator": "Iterator", + "AsyncGenerator": "Generator", + "StopAsyncIteration": "StopIteration", + "BigtableAsyncClient": "BigtableClient", + "AsyncRetry": "Retry", + "PooledBigtableGrpcAsyncIOTransport": "BigtableGrpcTransport", + "Awaitable": None, + "pytest_asyncio": "pytest", + "AsyncMock": "mock.Mock", + "_ReadRowsOperation": "_ReadRowsOperation_Sync", + "Table": "Table_Sync", + "BigtableDataClient": "BigtableDataClient_Sync", + "ReadRowsIterator": "ReadRowsIterator_Sync", + "_MutateRowsOperation": "_MutateRowsOperation_Sync", + "MutationsBatcher": "MutationsBatcher_Sync", + "_FlowControl": "_FlowControl_Sync", +} + +# This maps classes to the final sync surface location, so they can be instantiated in generated code +concrete_class_map = { + "_ReadRowsOperation": "google.cloud.bigtable._sync._concrete._ReadRowsOperation_Sync_Concrete", + "Table": "google.cloud.bigtable._sync._concrete.Table_Sync_Concrete", + "BigtableDataClient": "google.cloud.bigtable._sync._concrete.BigtableDataClient_Sync_Concrete", + "ReadRowsIterator": "google.cloud.bigtable._sync._concrete.ReadRowsIterator_Sync_Concrete", + "_MutateRowsOperation": "google.cloud.bigtable._sync._concrete._MutateRowsOperation_Sync_Concrete", + "MutationsBatcher": "google.cloud.bigtable._sync._concrete.MutationsBatcher_Threaded", + "_FlowControl": "google.cloud.bigtable._sync._concrete._FlowControl_Sync_Concrete", +} + +# This map defines import replacements for the generated code +# Note that 
"import ... as" statements keep the same name +import_map = { + ("google.api_core", "retry_async"): ("google.api_core", "retry"), + ( + "google.cloud.bigtable_v2.services.bigtable.async_client", + "BigtableAsyncClient", + ): ("google.cloud.bigtable_v2.services.bigtable.client", "BigtableClient"), + ("typing", "AsyncIterable"): ("typing", "Iterable"), + ("typing", "AsyncIterator"): ("typing", "Iterator"), + ("typing", "AsyncGenerator"): ("typing", "Generator"), + ( + "google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio", + "PooledBigtableGrpcAsyncIOTransport", + ): ( + "google.cloud.bigtable_v2.services.bigtable.transports.grpc", + "BigtableGrpcTransport", + ), + ("grpc.aio", "Channel"): ("grpc", "Channel"), +} + +# methods that are replaced with an empty implementation in sync surface +pass_methods=["__init__async__", "_prepare_stream", "_manage_channel", "_register_instance", "__init__transport__", "start_background_channel_refresh", "_start_idle_timer"] +# methods that are dropped from sync surface +drop_methods=["_buffer_to_generator", "_generator_to_buffer", "_idle_timeout_coroutine"] +# methods that raise a NotImplementedError in sync surface +error_methods=[] + +class AsyncToSyncTransformer(ast.NodeTransformer): + """ + This class is used to transform async classes into sync classes. + Generated classes are abstract, and must be subclassed to be used. + This is to ensure any required customizations from + outside of this autogeneration system are always applied + """ + + def __init__(self, *, import_replacements=None, asyncio_replacements=None, name_replacements=None, concrete_class_map=None, drop_methods=None, pass_methods=None, error_methods=None): + """ + Args: + - import_replacements: dict of (module, name) to (module, name) replacement import statements + For example, {("foo", "bar"): ("baz", "qux")} will replace "from foo import bar" with "from baz import qux" + - asyncio_replacements: dict of asyncio function names to the module/function name to replace them with + - name_replacements: dict of names to replace directly in the source code and docstrings + - concrete_class_map: dict of the concrete class names for all autogenerated classes + in this module, so that concrete versions of each class can be instantiated + - drop_methods: list of method names to drop from the class + - pass_methods: list of method names to replace with "pass" in the class + - error_methods: list of method names to replace with "raise NotImplementedError" in the class + """ + self.globals = {} + self.import_replacements = import_replacements or {} + self.asyncio_replacements = asyncio_replacements or {} + self.name_replacements = name_replacements or {} + self.concrete_class_map = concrete_class_map or {} + self.drop_methods = drop_methods or [] + self.pass_methods = pass_methods or [] + self.error_methods = error_methods or [] + + def update_docstring(self, docstring): + """ + Update docstring toreplace any key words in the name_replacements dict + """ + if not docstring: + return docstring + for key_word, replacement in self.name_replacements.items(): + docstring = docstring.replace(f" {key_word} ", f" {replacement} ") + if "\n" in docstring: + # if multiline docstring, add linebreaks to put the """ on a separate line + docstring = "\n" + docstring + "\n\n" + return docstring + + def visit_FunctionDef(self, node): + """ + Re-use replacement logic for Async functions + """ + return self.visit_AsyncFunctionDef(node) + + def visit_AsyncFunctionDef(self, node): + """ + Replace 
async functions with sync functions + """ + # replace docstring + docstring = self.update_docstring(ast.get_docstring(node)) + if isinstance(node.body[0], ast.Expr) and isinstance( + node.body[0].value, ast.Str + ): + node.body[0].value.s = docstring + # drop or replace body as needed + if node.name in self.drop_methods: + return None + elif node.name in self.pass_methods: + # replace with pass + node.body = [ast.Expr(value=ast.Str(s="Implementation purposely removed in sync mode"))] + elif node.name in self.error_methods: + self._create_error_node(node, "Function marked as unsupported in sync mode") + else: + # check if the function contains non-replaced usage of asyncio + func_ast = ast.parse(ast.unparse(node)) + has_asyncio_calls = any( + isinstance(n, ast.Call) + and isinstance(n.func, ast.Attribute) + and isinstance(n.func.value, ast.Name) + and n.func.value.id == "asyncio" + and n.func.attr not in self.asyncio_replacements + for n in ast.walk(func_ast) + ) + if has_asyncio_calls: + self._create_error_node( + node, + "Corresponding Async Function contains unhandled asyncio calls", + ) + # remove pytest.mark.asyncio decorator + if hasattr(node, "decorator_list"): + is_asyncio_decorator = lambda d: all(x in ast.dump(d) for x in ["pytest", "mark", "asyncio"]) + node.decorator_list = [ + d for d in node.decorator_list if not is_asyncio_decorator(d) + ] + return ast.copy_location( + ast.FunctionDef( + self.name_replacements.get(node.name, node.name), + self.visit(node.args), + [self.visit(stmt) for stmt in node.body], + [self.visit(stmt) for stmt in node.decorator_list], + node.returns and self.visit(node.returns), + ), + node, + ) + + def visit_Call(self, node): + # name replacement for class method calls + if isinstance(node.func, ast.Attribute) and isinstance( + node.func.value, ast.Name + ): + node.func.value.id = self.name_replacements.get(node.func.value.id, node.func.value.id) + # when initializing an auto-generated sync class, replace the class name with the patched version + if isinstance(node.func, ast.Name) and node.func.id in self.concrete_class_map: + node.func.id = self.concrete_class_map[node.func.id] + return ast.copy_location( + ast.Call( + self.visit(node.func), + [self.visit(arg) for arg in node.args], + [self.visit(keyword) for keyword in node.keywords], + ), + node, + ) + + def visit_Await(self, node): + return self.visit(node.value) + + def visit_Attribute(self, node): + if ( + isinstance(node.value, ast.Name) + and isinstance(node.value.ctx, ast.Load) + and node.value.id == "asyncio" + and node.attr in self.asyncio_replacements + ): + g, replacement = self.asyncio_replacements[node.attr] + self.globals.update(g) + return ast.copy_location(ast.parse(replacement, mode="eval").body, node) + elif isinstance(node, ast.Attribute) and node.attr in self.name_replacements: + new_node = ast.copy_location( + ast.Attribute(node.value, self.name_replacements[node.attr], node.ctx), node + ) + return new_node + return node + + def visit_Name(self, node): + node.id = self.name_replacements.get(node.id, node.id) + return node + + def visit_AsyncFor(self, node): + return ast.copy_location( + ast.For( + self.visit(node.target), + self.visit(node.iter), + [self.visit(stmt) for stmt in node.body], + [self.visit(stmt) for stmt in node.orelse], + ), + node, + ) + + def visit_AsyncWith(self, node): + return ast.copy_location( + ast.With( + [self.visit(item) for item in node.items], + [self.visit(stmt) for stmt in node.body], + ), + node, + ) + + def visit_ListComp(self, node): + # replace 
[x async for ...] with [x for ...] + new_generators = [] + for generator in node.generators: + if generator.is_async: + new_generators.append( + ast.copy_location( + ast.comprehension( + self.visit(generator.target), + self.visit(generator.iter), + [self.visit(i) for i in generator.ifs], + False, + ), + generator, + ) + ) + else: + new_generators.append(generator) + node.generators = new_generators + return ast.copy_location( + ast.ListComp( + self.visit(node.elt), + [self.visit(gen) for gen in node.generators], + ), + node, + ) + + def visit_Subscript(self, node): + if ( + hasattr(node, "value") + and isinstance(node.value, ast.Name) + and node.value.id == "AsyncGenerator" + and self.name_replacements.get(node.value.id, "") == "Generator" + ): + # Generator has different argument signature than AsyncGenerator + return ast.copy_location( + ast.Subscript( + ast.Name("Generator"), + ast.Index( + ast.Tuple( + [ + self.visit(i) + for i in node.slice.elts + [ast.Constant("Any")] + ] + ) + ), + node.ctx, + ), + node, + ) + elif ( + hasattr(node, "value") + and isinstance(node.value, ast.Name) + and self.name_replacements.get(node.value.id, False) is None + ): + # needed for Awaitable + return self.visit(node.slice) + return ast.copy_location( + ast.Subscript( + self.visit(node.value), + self.visit(node.slice), + node.ctx, + ), + node, + ) + + @staticmethod + def _create_error_node(node, error_msg): + # replace function body with NotImplementedError + exc_node = ast.Call( + func=ast.Name(id="NotImplementedError", ctx=ast.Load()), + args=[ast.Str(s=error_msg)], + keywords=[], + ) + raise_node = ast.Raise(exc=exc_node, cause=None) + node.body = [raise_node] + + + def get_imports(self, filename): + """ + Get the imports from a file, and do a find-and-replace against import_replacements + """ + imports = set() + with open(filename, "r") as f: + full_tree = ast.parse(f.read(), filename) + for node in ast.walk(full_tree): + if isinstance(node, (ast.Import, ast.ImportFrom)): + for alias in node.names: + if isinstance(node, ast.Import): + # import statments + new_import = self.import_replacements.get((alias.name), (alias.name)) + imports.add(ast.parse(f"import {new_import}").body[0]) + else: + # import from statements + # break into individual components + module, name = self.import_replacements.get( + (node.module, alias.name), (node.module, alias.name) + ) + # don't import from same file + if module == ".": + continue + asname_str = f" as {alias.asname}" if alias.asname else "" + imports.add( + ast.parse(f"from {module} import {name}{asname_str}").body[ + 0 + ] + ) + return imports + + +def transform_sync(class_list:list[Type], new_name_format="{}_Sync", add_imports=None, **kwargs): + combined_tree = ast.parse("") + combined_imports = set() + for in_obj in class_list: + filename = inspect.getfile(in_obj) + lines, lineno = inspect.getsourcelines(in_obj) + ast_tree = ast.parse(textwrap.dedent("".join(lines)), filename) + if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): + # update name + old_name = ast_tree.body[0].name + new_name = new_name_format.format(old_name) + ast_tree.body[0].name = new_name + ast.increment_lineno(ast_tree, lineno - 1) + # add ABC as base class + ast_tree.body[0].bases = ast_tree.body[0].bases + [ + ast.Name("ABC", ast.Load()), + ] + # remove top-level imports if any. 
Add them back later + ast_tree.body = [n for n in ast_tree.body if not isinstance(n, (ast.Import, ast.ImportFrom))] + # transform + transformer = AsyncToSyncTransformer(**kwargs) + transformer.visit(ast_tree) + # find imports + imports = transformer.get_imports(filename) + imports.add(ast.parse("from abc import ABC").body[0]) + # add globals + for g in transformer.globals: + imports.add(ast.parse(f"import {g}").body[0]) + # add locals from file, in case they are needed + if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): + file_basename = os.path.splitext(os.path.basename(filename))[0] + with open(filename, "r") as f: + for node in ast.walk(ast.parse(f.read(), filename)): + if isinstance(node, ast.ClassDef): + imports.add( + ast.parse( + f"from google.cloud.bigtable.{file_basename} import {node.name}" + ).body[0] + ) + # update combined data + combined_tree.body.extend(ast_tree.body) + combined_imports.update(imports) + # add extra imports + if add_imports: + for import_str in add_imports: + combined_imports.add(ast.parse(import_str).body[0]) + # render tree as string of code + import_unique = list(set([ast.unparse(i) for i in combined_imports])) + import_unique.sort() + google, non_google = [], [] + for i in import_unique: + if "google" in i: + google.append(i) + else: + non_google.append(i) + import_str = "\n".join(non_google + [""] + google) + # append clean tree + header = """# Copyright 2023 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # + # This file is automatically generated by sync_surface_generator.py. Do not edit. 
+ """ + full_code = f"{header}\n\n{import_str}\n\n{ast.unparse(combined_tree)}" + full_code = autoflake.fix_code(full_code, remove_all_unused_imports=True) + formatted_code = format_str(full_code, mode=FileMode()) + return formatted_code + + +def generate_full_surface(save_path=None): + """ + Generate a sync surface from all async classes + """ + from google.cloud.bigtable._read_rows import _ReadRowsOperation + from google.cloud.bigtable._mutate_rows import _MutateRowsOperation + from google.cloud.bigtable.client import Table + from google.cloud.bigtable.client import BigtableDataClient + from google.cloud.bigtable.iterators import ReadRowsIterator + from google.cloud.bigtable.mutations_batcher import MutationsBatcher + from google.cloud.bigtable.mutations_batcher import _FlowControl + + conversion_list = [_ReadRowsOperation, Table, BigtableDataClient, ReadRowsIterator, _MutateRowsOperation, MutationsBatcher, _FlowControl] + code = transform_sync(conversion_list, + concrete_class_map=concrete_class_map, name_replacements=name_map, asyncio_replacements=asynciomap, import_replacements=import_map, + pass_methods=pass_methods, drop_methods=drop_methods, error_methods=error_methods, + add_imports=["import google.cloud.bigtable.exceptions as bt_exceptions"], + ) + if save_path is not None: + with open(save_path, "w") as f: + f.write(code) + return code + +def generate_system_tests(save_path=None): + from tests.system import test_system + conversion_list = [test_system] + code = transform_sync(conversion_list, + concrete_class_map=concrete_class_map, name_replacements=name_map, asyncio_replacements=asynciomap, import_replacements=import_map, + drop_methods=["test_read_rows_stream_inactive_timer"], + add_imports=["import google.cloud.bigtable"] + ) + if save_path is not None: + with open(save_path, "w") as f: + f.write(code) + return code + +def generate_unit_tests(test_path="./tests/unit", save_path=None): + """ + Unit tests should typically not be generated using this script. + But this is useful to generate a starting point. 
+ """ + import importlib + if save_path is None: + save_path = os.path.join(test_path, "sync") + updated_list = [] + # find files in test_path + conversion_list = [f for f in os.listdir(test_path) if f.endswith(".py")] + # attempt tp convert each file + for f in conversion_list: + old_code = open(os.path.join(test_path, f), "r").read() + obj = importlib.import_module(f"tests.unit.{f[:-3]}") + new_code = transform_sync([obj], + concrete_class_map=concrete_class_map, name_replacements=name_map, asyncio_replacements=asynciomap, import_replacements=import_map, + add_imports=["import google.cloud.bigtable"] + ) + # only save files with async code + if any([a in new_code for a in asynciomap]) or "async def" in old_code: + with open(os.path.join(save_path, f), "w") as out: + out.write(new_code) + updated_list.append(f) + print(f"Updated {len(updated_list)} files: {updated_list}") + return updated_list + +if __name__ == "__main__": + generate_full_surface(save_path="./google/cloud/bigtable/_sync/_autogen.py") + generate_system_tests("./tests/system/test_system_sync_autogen.py") From fb2ae39074fcd6054c25aee30529781cb8ea0838 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Apr 2024 13:11:34 -0700 Subject: [PATCH 002/360] simplified generator --- sync_surface_generator.py | 132 ++------------------------------------ 1 file changed, 4 insertions(+), 128 deletions(-) diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 5eea38425..47c3075a1 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -30,80 +30,6 @@ This module allows us to generate a synchronous API surface from our asyncio surface. """ -# This map defines replacements for asyncio API calls -asynciomap = { - "sleep": ({"time": time}, "time.sleep"), - "Queue": ({"queue": queue}, "queue.Queue"), - "Condition": ({"threading": threading}, "threading.Condition"), - "Future": ({"concurrent.futures": concurrent.futures}, "concurrent.futures.Future"), -} - -# This map defines find/replace pairs for the generated code and docstrings -# replace async calls with corresponding sync ones -name_map = { - "__anext__": "__next__", - "__aiter__": "__iter__", - "__aenter__": "__enter__", - "__aexit__": "__exit__", - "aclose": "close", - "AsyncIterable": "Iterable", - "AsyncIterator": "Iterator", - "AsyncGenerator": "Generator", - "StopAsyncIteration": "StopIteration", - "BigtableAsyncClient": "BigtableClient", - "AsyncRetry": "Retry", - "PooledBigtableGrpcAsyncIOTransport": "BigtableGrpcTransport", - "Awaitable": None, - "pytest_asyncio": "pytest", - "AsyncMock": "mock.Mock", - "_ReadRowsOperation": "_ReadRowsOperation_Sync", - "Table": "Table_Sync", - "BigtableDataClient": "BigtableDataClient_Sync", - "ReadRowsIterator": "ReadRowsIterator_Sync", - "_MutateRowsOperation": "_MutateRowsOperation_Sync", - "MutationsBatcher": "MutationsBatcher_Sync", - "_FlowControl": "_FlowControl_Sync", -} - -# This maps classes to the final sync surface location, so they can be instantiated in generated code -concrete_class_map = { - "_ReadRowsOperation": "google.cloud.bigtable._sync._concrete._ReadRowsOperation_Sync_Concrete", - "Table": "google.cloud.bigtable._sync._concrete.Table_Sync_Concrete", - "BigtableDataClient": "google.cloud.bigtable._sync._concrete.BigtableDataClient_Sync_Concrete", - "ReadRowsIterator": "google.cloud.bigtable._sync._concrete.ReadRowsIterator_Sync_Concrete", - "_MutateRowsOperation": "google.cloud.bigtable._sync._concrete._MutateRowsOperation_Sync_Concrete", - "MutationsBatcher": 
"google.cloud.bigtable._sync._concrete.MutationsBatcher_Threaded", - "_FlowControl": "google.cloud.bigtable._sync._concrete._FlowControl_Sync_Concrete", -} - -# This map defines import replacements for the generated code -# Note that "import ... as" statements keep the same name -import_map = { - ("google.api_core", "retry_async"): ("google.api_core", "retry"), - ( - "google.cloud.bigtable_v2.services.bigtable.async_client", - "BigtableAsyncClient", - ): ("google.cloud.bigtable_v2.services.bigtable.client", "BigtableClient"), - ("typing", "AsyncIterable"): ("typing", "Iterable"), - ("typing", "AsyncIterator"): ("typing", "Iterator"), - ("typing", "AsyncGenerator"): ("typing", "Generator"), - ( - "google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio", - "PooledBigtableGrpcAsyncIOTransport", - ): ( - "google.cloud.bigtable_v2.services.bigtable.transports.grpc", - "BigtableGrpcTransport", - ), - ("grpc.aio", "Channel"): ("grpc", "Channel"), -} - -# methods that are replaced with an empty implementation in sync surface -pass_methods=["__init__async__", "_prepare_stream", "_manage_channel", "_register_instance", "__init__transport__", "start_background_channel_refresh", "_start_idle_timer"] -# methods that are dropped from sync surface -drop_methods=["_buffer_to_generator", "_generator_to_buffer", "_idle_timeout_coroutine"] -# methods that raise a NotImplementedError in sync surface -error_methods=[] - class AsyncToSyncTransformer(ast.NodeTransformer): """ This class is used to transform async classes into sync classes. @@ -459,65 +385,15 @@ def generate_full_surface(save_path=None): """ Generate a sync surface from all async classes """ - from google.cloud.bigtable._read_rows import _ReadRowsOperation - from google.cloud.bigtable._mutate_rows import _MutateRowsOperation - from google.cloud.bigtable.client import Table - from google.cloud.bigtable.client import BigtableDataClient - from google.cloud.bigtable.iterators import ReadRowsIterator - from google.cloud.bigtable.mutations_batcher import MutationsBatcher - from google.cloud.bigtable.mutations_batcher import _FlowControl + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - conversion_list = [_ReadRowsOperation, Table, BigtableDataClient, ReadRowsIterator, _MutateRowsOperation, MutationsBatcher, _FlowControl] - code = transform_sync(conversion_list, - concrete_class_map=concrete_class_map, name_replacements=name_map, asyncio_replacements=asynciomap, import_replacements=import_map, - pass_methods=pass_methods, drop_methods=drop_methods, error_methods=error_methods, - add_imports=["import google.cloud.bigtable.exceptions as bt_exceptions"], - ) + conversion_list = [_ReadRowsOperationAsync] + code = transform_sync(conversion_list) if save_path is not None: with open(save_path, "w") as f: f.write(code) return code -def generate_system_tests(save_path=None): - from tests.system import test_system - conversion_list = [test_system] - code = transform_sync(conversion_list, - concrete_class_map=concrete_class_map, name_replacements=name_map, asyncio_replacements=asynciomap, import_replacements=import_map, - drop_methods=["test_read_rows_stream_inactive_timer"], - add_imports=["import google.cloud.bigtable"] - ) - if save_path is not None: - with open(save_path, "w") as f: - f.write(code) - return code - -def generate_unit_tests(test_path="./tests/unit", save_path=None): - """ - Unit tests should typically not be generated using this script. - But this is useful to generate a starting point. 
- """ - import importlib - if save_path is None: - save_path = os.path.join(test_path, "sync") - updated_list = [] - # find files in test_path - conversion_list = [f for f in os.listdir(test_path) if f.endswith(".py")] - # attempt tp convert each file - for f in conversion_list: - old_code = open(os.path.join(test_path, f), "r").read() - obj = importlib.import_module(f"tests.unit.{f[:-3]}") - new_code = transform_sync([obj], - concrete_class_map=concrete_class_map, name_replacements=name_map, asyncio_replacements=asynciomap, import_replacements=import_map, - add_imports=["import google.cloud.bigtable"] - ) - # only save files with async code - if any([a in new_code for a in asynciomap]) or "async def" in old_code: - with open(os.path.join(save_path, f), "w") as out: - out.write(new_code) - updated_list.append(f) - print(f"Updated {len(updated_list)} files: {updated_list}") - return updated_list if __name__ == "__main__": - generate_full_surface(save_path="./google/cloud/bigtable/_sync/_autogen.py") - generate_system_tests("./tests/system/test_system_sync_autogen.py") + generate_full_surface(save_path="./google/cloud/bigtable/data/_sync/_autogen.py") From e7889334d857dbee94c54a3312bbb162cc4ad4d3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Apr 2024 13:39:45 -0700 Subject: [PATCH 003/360] refactoring --- sync_surface_generator.py | 111 +++++++++++++++++++++----------------- 1 file changed, 61 insertions(+), 50 deletions(-) diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 47c3075a1..d571e5902 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -23,6 +23,7 @@ import os import threading import concurrent.futures +import importlib from black import format_str, FileMode import autoflake @@ -301,48 +302,52 @@ def get_imports(self, filename): ) return imports +def transform_class(in_obj: Type, new_name_format="{}_Sync", **kwargs): + filename = inspect.getfile(in_obj) + lines, lineno = inspect.getsourcelines(in_obj) + ast_tree = ast.parse(textwrap.dedent("".join(lines)), filename) + if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): + # update name + old_name = ast_tree.body[0].name + new_name = new_name_format.format(old_name) + ast_tree.body[0].name = new_name + ast.increment_lineno(ast_tree, lineno - 1) + # add ABC as base class + ast_tree.body[0].bases = ast_tree.body[0].bases + [ + ast.Name("ABC", ast.Load()), + ] + # remove top-level imports if any. 
Add them back later + ast_tree.body = [n for n in ast_tree.body if not isinstance(n, (ast.Import, ast.ImportFrom))] + # transform + transformer = AsyncToSyncTransformer(**kwargs) + transformer.visit(ast_tree) + # find imports + imports = transformer.get_imports(filename) + imports.add(ast.parse("from abc import ABC").body[0]) + # add globals + for g in transformer.globals: + imports.add(ast.parse(f"import {g}").body[0]) + # add locals from file, in case they are needed + if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): + file_basename = os.path.splitext(os.path.basename(filename))[0] + with open(filename, "r") as f: + for node in ast.walk(ast.parse(f.read(), filename)): + if isinstance(node, ast.ClassDef): + imports.add( + ast.parse( + f"from google.cloud.bigtable.{file_basename} import {node.name}" + ).body[0] + ) + return ast_tree.body, imports + -def transform_sync(class_list:list[Type], new_name_format="{}_Sync", add_imports=None, **kwargs): +def transform_all(class_list:list[Type], new_name_format="{}_Sync", add_imports=None, **kwargs): combined_tree = ast.parse("") combined_imports = set() for in_obj in class_list: - filename = inspect.getfile(in_obj) - lines, lineno = inspect.getsourcelines(in_obj) - ast_tree = ast.parse(textwrap.dedent("".join(lines)), filename) - if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): - # update name - old_name = ast_tree.body[0].name - new_name = new_name_format.format(old_name) - ast_tree.body[0].name = new_name - ast.increment_lineno(ast_tree, lineno - 1) - # add ABC as base class - ast_tree.body[0].bases = ast_tree.body[0].bases + [ - ast.Name("ABC", ast.Load()), - ] - # remove top-level imports if any. Add them back later - ast_tree.body = [n for n in ast_tree.body if not isinstance(n, (ast.Import, ast.ImportFrom))] - # transform - transformer = AsyncToSyncTransformer(**kwargs) - transformer.visit(ast_tree) - # find imports - imports = transformer.get_imports(filename) - imports.add(ast.parse("from abc import ABC").body[0]) - # add globals - for g in transformer.globals: - imports.add(ast.parse(f"import {g}").body[0]) - # add locals from file, in case they are needed - if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): - file_basename = os.path.splitext(os.path.basename(filename))[0] - with open(filename, "r") as f: - for node in ast.walk(ast.parse(f.read(), filename)): - if isinstance(node, ast.ClassDef): - imports.add( - ast.parse( - f"from google.cloud.bigtable.{file_basename} import {node.name}" - ).body[0] - ) + tree_body, imports = transform_class(in_obj, new_name_format, **kwargs) # update combined data - combined_tree.body.extend(ast_tree.body) + combined_tree.body.extend(tree_body) combined_imports.update(imports) # add extra imports if add_imports: @@ -381,19 +386,25 @@ def transform_sync(class_list:list[Type], new_name_format="{}_Sync", add_imports return formatted_code -def generate_full_surface(save_path=None): - """ - Generate a sync surface from all async classes - """ - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - conversion_list = [_ReadRowsOperationAsync] - code = transform_sync(conversion_list) - if save_path is not None: - with open(save_path, "w") as f: - f.write(code) - return code +if __name__ == "__main__": + classes = [ + { + "path": "google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync", + "drop_methods": ["read_rows"], + } + ] + + save_path = "./google/cloud/bigtable/data/_sync/_autogen.py" + + for class_dict in classes: + # 
convert string class path into class object + module_path, class_name = class_dict["path"].rsplit(".", 1) + class_object = getattr(importlib.import_module(module_path), class_name) + + code = transform_all([class_object]) + if save_path is not None: + with open(save_path, "w") as f: + f.write(code) -if __name__ == "__main__": - generate_full_surface(save_path="./google/cloud/bigtable/data/_sync/_autogen.py") From 7c740efb3881ca06b4f0d68b8c6780fc6b5f1b9e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Apr 2024 14:06:11 -0700 Subject: [PATCH 004/360] adding better support for yaml config --- sync_surface_generator.py | 54 ++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/sync_surface_generator.py b/sync_surface_generator.py index d571e5902..22f1d34ed 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -302,15 +302,15 @@ def get_imports(self, filename): ) return imports -def transform_class(in_obj: Type, new_name_format="{}_Sync", **kwargs): +def transform_class(in_obj: Type, **kwargs): filename = inspect.getfile(in_obj) lines, lineno = inspect.getsourcelines(in_obj) ast_tree = ast.parse(textwrap.dedent("".join(lines)), filename) if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): # update name old_name = ast_tree.body[0].name - new_name = new_name_format.format(old_name) - ast_tree.body[0].name = new_name + # set default name for new class if unset + ast_tree.body[0].name = kwargs.pop("autogen_sync_name", f"{old_name}_SyncGen") ast.increment_lineno(ast_tree, lineno - 1) # add ABC as base class ast_tree.body[0].bases = ast_tree.body[0].bases + [ @@ -341,18 +341,24 @@ def transform_class(in_obj: Type, new_name_format="{}_Sync", **kwargs): return ast_tree.body, imports -def transform_all(class_list:list[Type], new_name_format="{}_Sync", add_imports=None, **kwargs): +def transform_from_config(config_dict: dict): + # initialize new tree and import list combined_tree = ast.parse("") combined_imports = set() - for in_obj in class_list: - tree_body, imports = transform_class(in_obj, new_name_format, **kwargs) + # process each class + for class_dict in config_dict["classes"]: + # convert string class path into class object + module_path, class_name = class_dict.pop("path").rsplit(".", 1) + class_object = getattr(importlib.import_module(module_path), class_name) + # transform class + tree_body, imports = transform_class(class_object, **class_dict) # update combined data combined_tree.body.extend(tree_body) combined_imports.update(imports) # add extra imports - if add_imports: - for import_str in add_imports: - combined_imports.add(ast.parse(import_str).body[0]) + # if add_imports: + # for import_str in add_imports: + # combined_imports.add(ast.parse(import_str).body[0]) # render tree as string of code import_unique = list(set([ast.unparse(i) for i in combined_imports])) import_unique.sort() @@ -386,25 +392,25 @@ def transform_all(class_list:list[Type], new_name_format="{}_Sync", add_imports= return formatted_code - if __name__ == "__main__": - classes = [ - { - "path": "google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync", - "drop_methods": ["read_rows"], - } - ] + config = { + "classes": [ + { + "path": "google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync", + "autogen_sync_name": "_ReadRowsOperation_SyncGen", + "pass_methods": ["start_operation"], + "drop_methods": [], + "error_methods": [], + } + ] + } save_path = "./google/cloud/bigtable/data/_sync/_autogen.py" - for 
class_dict in classes: - # convert string class path into class object - module_path, class_name = class_dict["path"].rsplit(".", 1) - class_object = getattr(importlib.import_module(module_path), class_name) + code = transform_from_config(config) - code = transform_all([class_object]) - if save_path is not None: - with open(save_path, "w") as f: - f.write(code) + if save_path is not None: + with open(save_path, "w") as f: + f.write(code) From ff0458dc0f8d8b9f882392b3d952e75ea3433ad4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Apr 2024 14:43:17 -0700 Subject: [PATCH 005/360] updated import maps --- sync_surface_generator.py | 42 +++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 22f1d34ed..ba189a7e4 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -39,12 +39,11 @@ class AsyncToSyncTransformer(ast.NodeTransformer): outside of this autogeneration system are always applied """ - def __init__(self, *, import_replacements=None, asyncio_replacements=None, name_replacements=None, concrete_class_map=None, drop_methods=None, pass_methods=None, error_methods=None): + def __init__(self, *, import_replacements=None, name_replacements=None, concrete_class_map=None, drop_methods=None, pass_methods=None, error_methods=None): """ Args: - import_replacements: dict of (module, name) to (module, name) replacement import statements For example, {("foo", "bar"): ("baz", "qux")} will replace "from foo import bar" with "from baz import qux" - - asyncio_replacements: dict of asyncio function names to the module/function name to replace them with - name_replacements: dict of names to replace directly in the source code and docstrings - concrete_class_map: dict of the concrete class names for all autogenerated classes in this module, so that concrete versions of each class can be instantiated @@ -52,9 +51,7 @@ def __init__(self, *, import_replacements=None, asyncio_replacements=None, name_ - pass_methods: list of method names to replace with "pass" in the class - error_methods: list of method names to replace with "raise NotImplementedError" in the class """ - self.globals = {} self.import_replacements = import_replacements or {} - self.asyncio_replacements = asyncio_replacements or {} self.name_replacements = name_replacements or {} self.concrete_class_map = concrete_class_map or {} self.drop_methods = drop_methods or [] @@ -106,7 +103,7 @@ def visit_AsyncFunctionDef(self, node): and isinstance(n.func, ast.Attribute) and isinstance(n.func.value, ast.Name) and n.func.value.id == "asyncio" - and n.func.attr not in self.asyncio_replacements + and f"asyncio,{n.func.attr}" not in self.asyncio_replacements for n in ast.walk(func_ast) ) if has_asyncio_calls: @@ -157,10 +154,9 @@ def visit_Attribute(self, node): isinstance(node.value, ast.Name) and isinstance(node.value.ctx, ast.Load) and node.value.id == "asyncio" - and node.attr in self.asyncio_replacements + and f"asyncio.{node.attr}" in self.import_replacements ): - g, replacement = self.asyncio_replacements[node.attr] - self.globals.update(g) + replacement = self.import_replacements[f"asyncio.{node.attr}"] return ast.copy_location(ast.parse(replacement, mode="eval").body, node) elif isinstance(node, ast.Attribute) and node.attr in self.name_replacements: new_node = ast.copy_location( @@ -283,14 +279,15 @@ def get_imports(self, filename): for alias in node.names: if isinstance(node, ast.Import): # import statments - 
new_import = self.import_replacements.get((alias.name), (alias.name)) + new_import = self.import_replacements.get(alias.name, alias.name) imports.add(ast.parse(f"import {new_import}").body[0]) else: # import from statements # break into individual components - module, name = self.import_replacements.get( - (node.module, alias.name), (node.module, alias.name) - ) + full_path = f"{node.module}.{alias.name}" + if full_path in self.import_replacements: + full_path = self.import_replacements[full_path] + module, name = full_path.rsplit(".", 1) # don't import from same file if module == ".": continue @@ -325,8 +322,8 @@ def transform_class(in_obj: Type, **kwargs): imports = transformer.get_imports(filename) imports.add(ast.parse("from abc import ABC").body[0]) # add globals - for g in transformer.globals: - imports.add(ast.parse(f"import {g}").body[0]) + # for g in transformer.globals: + # imports.add(ast.parse(f"import {g}").body[0]) # add locals from file, in case they are needed if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): file_basename = os.path.splitext(os.path.basename(filename))[0] @@ -350,6 +347,8 @@ def transform_from_config(config_dict: dict): # convert string class path into class object module_path, class_name = class_dict.pop("path").rsplit(".", 1) class_object = getattr(importlib.import_module(module_path), class_name) + # add globals to class_dict + class_dict["import_replacements"] = {**config_dict.get("import_replacements", {}), **class_dict.get("import_replacements", {})} # transform class tree_body, imports = transform_class(class_object, **class_dict) # update combined data @@ -394,11 +393,24 @@ def transform_from_config(config_dict: dict): if __name__ == "__main__": config = { + "import_replacements": { + # "asyncio.sleep": "time.sleep", + # "asyncio.Queue": "queue.Queue", + # "asyncio.Condition": "threading.Condition", + # "asyncio.Future": "concurrent.futures.Future", + "google.api_core.retry_async": "google.api_core.retry", + "typing.AsyncIterable": "typing.Iterable", + "typing.AsyncIterator": "typing.Iterator", + "typing.AsyncGenerator": "typing.Generator", + "grpc.aio.Channel": "grpc.Channel", + "google.cloud.bigtable_v2.services.bigtable.async_client.BigtableAsyncClient": "google.cloud.bigtable_v2.services.bigtable.client.BigtableClient", + "google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio.PooledBigtableGrpcAsyncIOTransport": "google.cloud.bigtable_v2.services.bigtable.transports.BigtableGrpcTransport", + }, # replace imports with corresponding sync version. 
Does not touch the code, only import lines "classes": [ { "path": "google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync", "autogen_sync_name": "_ReadRowsOperation_SyncGen", - "pass_methods": ["start_operation"], + "pass_methods": [], "drop_methods": [], "error_methods": [], } From fcf2a2f3f6cee94b8b7a3cfa3379e34b5633e459 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Apr 2024 14:50:31 -0700 Subject: [PATCH 006/360] added back name_replacements --- sync_surface_generator.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/sync_surface_generator.py b/sync_surface_generator.py index ba189a7e4..e29d69633 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -60,7 +60,7 @@ def __init__(self, *, import_replacements=None, name_replacements=None, concrete def update_docstring(self, docstring): """ - Update docstring toreplace any key words in the name_replacements dict + Update docstring to replace any key words in the name_replacements dict """ if not docstring: return docstring @@ -103,7 +103,7 @@ def visit_AsyncFunctionDef(self, node): and isinstance(n.func, ast.Attribute) and isinstance(n.func.value, ast.Name) and n.func.value.id == "asyncio" - and f"asyncio,{n.func.attr}" not in self.asyncio_replacements + and f"asyncio.{n.func.attr}" not in self.import_replacements for n in ast.walk(func_ast) ) if has_asyncio_calls: @@ -349,6 +349,7 @@ def transform_from_config(config_dict: dict): class_object = getattr(importlib.import_module(module_path), class_name) # add globals to class_dict class_dict["import_replacements"] = {**config_dict.get("import_replacements", {}), **class_dict.get("import_replacements", {})} + class_dict["name_replacements"] = {**config_dict.get("name_replacements", {}), **class_dict.get("name_replacements", {})} # transform class tree_body, imports = transform_class(class_object, **class_dict) # update combined data @@ -403,9 +404,33 @@ def transform_from_config(config_dict: dict): "typing.AsyncIterator": "typing.Iterator", "typing.AsyncGenerator": "typing.Generator", "grpc.aio.Channel": "grpc.Channel", - "google.cloud.bigtable_v2.services.bigtable.async_client.BigtableAsyncClient": "google.cloud.bigtable_v2.services.bigtable.client.BigtableClient", - "google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio.PooledBigtableGrpcAsyncIOTransport": "google.cloud.bigtable_v2.services.bigtable.transports.BigtableGrpcTransport", + # "google.cloud.bigtable_v2.services.bigtable.async_client.BigtableAsyncClient": "google.cloud.bigtable_v2.services.bigtable.client.BigtableClient", + # "google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio.PooledBigtableGrpcAsyncIOTransport": "google.cloud.bigtable_v2.services.bigtable.transports.BigtableGrpcTransport", }, # replace imports with corresponding sync version. 
Does not touch the code, only import lines + "name_replacements": { + "__anext__": "__next__", + "__aiter__": "__iter__", + "__aenter__": "__enter__", + "__aexit__": "__exit__", + "aclose": "close", + "AsyncIterable": "Iterable", + "AsyncIterator": "Iterator", + "AsyncGenerator": "Generator", + "StopAsyncIteration": "StopIteration", + "BigtableAsyncClient": "BigtableClient", + "AsyncRetry": "Retry", + "PooledBigtableGrpcAsyncIOTransport": "BigtableGrpcTransport", + "Awaitable": None, + "pytest_asyncio": "pytest", + "AsyncMock": "mock.Mock", + # "_ReadRowsOperation": "_ReadRowsOperation_Sync", + # "Table": "Table_Sync", + # "BigtableDataClient": "BigtableDataClient_Sync", + # "ReadRowsIterator": "ReadRowsIterator_Sync", + # "_MutateRowsOperation": "_MutateRowsOperation_Sync", + # "MutationsBatcher": "MutationsBatcher_Sync", + # "_FlowControl": "_FlowControl_Sync", + }, # performs find/replace for these terms in docstrings and generated code "classes": [ { "path": "google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync", From 51d9c7fecd636a39588f4c0bf8aa1e1d0c229dc3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Apr 2024 15:01:47 -0700 Subject: [PATCH 007/360] added configs for all async classes --- sync_surface_generator.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/sync_surface_generator.py b/sync_surface_generator.py index e29d69633..d4b368f75 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -438,7 +438,27 @@ def transform_from_config(config_dict: dict): "pass_methods": [], "drop_methods": [], "error_methods": [], - } + }, + { + "path": "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync", + "autogen_sync_name": "_MutateRowsOperation_SyncGen", + }, + { + "path": "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync", + "autogen_sync_name": "MutationsBatcher_SyncGen", + }, + { + "path": "google.cloud.bigtable.data._async.mutations_batcher._FlowControlAsync", + "autogen_sync_name": "_FlowControl_SyncGen", + }, + { + "path": "google.cloud.bigtable.data._async.client.BigtableDataClientAsync", + "autogen_sync_name": "BigtableDataClient_SyncGen", + }, + { + "path": "google.cloud.bigtable.data._async.client.TableAsync", + "autogen_sync_name": "Table_SyncGen", + }, ] } From 3fefea49a40b9c6018b544e5f799dd73be9edd03 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Apr 2024 15:11:16 -0700 Subject: [PATCH 008/360] removed concrete class map --- sync_surface_generator.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sync_surface_generator.py b/sync_surface_generator.py index d4b368f75..d563f5c5b 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -39,21 +39,18 @@ class AsyncToSyncTransformer(ast.NodeTransformer): outside of this autogeneration system are always applied """ - def __init__(self, *, import_replacements=None, name_replacements=None, concrete_class_map=None, drop_methods=None, pass_methods=None, error_methods=None): + def __init__(self, *, import_replacements=None, name_replacements=None, drop_methods=None, pass_methods=None, error_methods=None): """ Args: - import_replacements: dict of (module, name) to (module, name) replacement import statements For example, {("foo", "bar"): ("baz", "qux")} will replace "from foo import bar" with "from baz import qux" - name_replacements: dict of names to replace directly in the source code and docstrings - - concrete_class_map: dict of the concrete class 
names for all autogenerated classes - in this module, so that concrete versions of each class can be instantiated - drop_methods: list of method names to drop from the class - pass_methods: list of method names to replace with "pass" in the class - error_methods: list of method names to replace with "raise NotImplementedError" in the class """ self.import_replacements = import_replacements or {} self.name_replacements = name_replacements or {} - self.concrete_class_map = concrete_class_map or {} self.drop_methods = drop_methods or [] self.pass_methods = pass_methods or [] self.error_methods = error_methods or [] @@ -134,9 +131,6 @@ def visit_Call(self, node): node.func.value, ast.Name ): node.func.value.id = self.name_replacements.get(node.func.value.id, node.func.value.id) - # when initializing an auto-generated sync class, replace the class name with the patched version - if isinstance(node.func, ast.Name) and node.func.id in self.concrete_class_map: - node.func.id = self.concrete_class_map[node.func.id] return ast.copy_location( ast.Call( self.visit(node.func), From b6a89dd56a41153dadd75a109a27abd633c46063 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Apr 2024 15:37:33 -0700 Subject: [PATCH 009/360] improced handling of unexpected asyncio tasks --- sync_surface_generator.py | 46 +++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/sync_surface_generator.py b/sync_surface_generator.py index d563f5c5b..77da234f6 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -39,9 +39,10 @@ class AsyncToSyncTransformer(ast.NodeTransformer): outside of this autogeneration system are always applied """ - def __init__(self, *, import_replacements=None, name_replacements=None, drop_methods=None, pass_methods=None, error_methods=None): + def __init__(self, *, name=None, import_replacements=None, name_replacements=None, drop_methods=None, pass_methods=None, error_methods=None): """ Args: + - name: the name of the class being processed. 
Just used in exceptions - import_replacements: dict of (module, name) to (module, name) replacement import statements For example, {("foo", "bar"): ("baz", "qux")} will replace "from foo import bar" with "from baz import qux" - name_replacements: dict of names to replace directly in the source code and docstrings @@ -49,6 +50,7 @@ def __init__(self, *, import_replacements=None, name_replacements=None, drop_met - pass_methods: list of method names to replace with "pass" in the class - error_methods: list of method names to replace with "raise NotImplementedError" in the class """ + self.name = name self.import_replacements = import_replacements or {} self.name_replacements = name_replacements or {} self.drop_methods = drop_methods or [] @@ -95,19 +97,16 @@ def visit_AsyncFunctionDef(self, node): else: # check if the function contains non-replaced usage of asyncio func_ast = ast.parse(ast.unparse(node)) - has_asyncio_calls = any( - isinstance(n, ast.Call) - and isinstance(n.func, ast.Attribute) - and isinstance(n.func.value, ast.Name) - and n.func.value.id == "asyncio" - and f"asyncio.{n.func.attr}" not in self.import_replacements - for n in ast.walk(func_ast) - ) - if has_asyncio_calls: - self._create_error_node( - node, - "Corresponding Async Function contains unhandled asyncio calls", - ) + for n in ast.walk(func_ast): + if isinstance(n, ast.Call) \ + and isinstance(n.func, ast.Attribute) \ + and isinstance(n.func.value, ast.Name) \ + and n.func.value.id == "asyncio" \ + and f"asyncio.{n.func.attr}" not in self.import_replacements: + path_str = f"{self.name}.{node.name}" if self.name else node.name + raise RuntimeError( + f"{path_str} contains unhandled asyncio calls: {n.func.attr}. Add method to drop_methods, pass_methods, or error_methods to handle safely." + ) # remove pytest.mark.asyncio decorator if hasattr(node, "decorator_list"): is_asyncio_decorator = lambda d: all(x in ast.dump(d) for x in ["pytest", "mark", "asyncio"]) @@ -297,11 +296,13 @@ def transform_class(in_obj: Type, **kwargs): filename = inspect.getfile(in_obj) lines, lineno = inspect.getsourcelines(in_obj) ast_tree = ast.parse(textwrap.dedent("".join(lines)), filename) + new_name = None if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): # update name old_name = ast_tree.body[0].name # set default name for new class if unset - ast_tree.body[0].name = kwargs.pop("autogen_sync_name", f"{old_name}_SyncGen") + new_name = kwargs.pop("autogen_sync_name", f"{old_name}_SyncGen") + ast_tree.body[0].name = new_name ast.increment_lineno(ast_tree, lineno - 1) # add ABC as base class ast_tree.body[0].bases = ast_tree.body[0].bases + [ @@ -310,7 +311,7 @@ def transform_class(in_obj: Type, **kwargs): # remove top-level imports if any. Add them back later ast_tree.body = [n for n in ast_tree.body if not isinstance(n, (ast.Import, ast.ImportFrom))] # transform - transformer = AsyncToSyncTransformer(**kwargs) + transformer = AsyncToSyncTransformer(name=new_name, **kwargs) transformer.visit(ast_tree) # find imports imports = transformer.get_imports(filename) @@ -364,7 +365,7 @@ def transform_from_config(config_dict: dict): non_google.append(i) import_str = "\n".join(non_google + [""] + google) # append clean tree - header = """# Copyright 2023 Google LLC + header = """# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -389,10 +390,10 @@ def transform_from_config(config_dict: dict): if __name__ == "__main__": config = { "import_replacements": { - # "asyncio.sleep": "time.sleep", - # "asyncio.Queue": "queue.Queue", - # "asyncio.Condition": "threading.Condition", - # "asyncio.Future": "concurrent.futures.Future", + "asyncio.sleep": "time.sleep", + "asyncio.Queue": "queue.Queue", + "asyncio.Condition": "threading.Condition", + "asyncio.Future": "concurrent.futures.Future", "google.api_core.retry_async": "google.api_core.retry", "typing.AsyncIterable": "typing.Iterable", "typing.AsyncIterator": "typing.Iterator", @@ -440,6 +441,7 @@ def transform_from_config(config_dict: dict): { "path": "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync", "autogen_sync_name": "MutationsBatcher_SyncGen", + "pass_methods": ["_start_flush_timer", "close", "_create_bg_task", "_wait_for_batch_results"] }, { "path": "google.cloud.bigtable.data._async.mutations_batcher._FlowControlAsync", @@ -448,10 +450,12 @@ def transform_from_config(config_dict: dict): { "path": "google.cloud.bigtable.data._async.client.BigtableDataClientAsync", "autogen_sync_name": "BigtableDataClient_SyncGen", + "pass_methods": ["_start_background_channel_refresh", "close", "_ping_and_warm_instances"] }, { "path": "google.cloud.bigtable.data._async.client.TableAsync", "autogen_sync_name": "Table_SyncGen", + "pass_methods": ["__init__", "read_rows_sharded"] }, ] } From 4d18e58d6606a0f3e8ceeb418a60154df96a9c9b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Apr 2024 16:02:27 -0700 Subject: [PATCH 010/360] added concrete classes --- google/cloud/bigtable/data/__init__.py | 10 +- google/cloud/bigtable/data/_sync/_autogen.py | 1765 +++++++++++++++++ .../cloud/bigtable/data/_sync/_mutate_rows.py | 22 + .../cloud/bigtable/data/_sync/_read_rows.py | 22 + google/cloud/bigtable/data/_sync/client.py | 45 + .../bigtable/data/_sync/mutations_batcher.py | 26 + .../cloud/bigtable/data/_sync/sync_gen.yaml | 3 + 7 files changed, 1892 insertions(+), 1 deletion(-) create mode 100644 google/cloud/bigtable/data/_sync/_autogen.py create mode 100644 google/cloud/bigtable/data/_sync/_mutate_rows.py create mode 100644 google/cloud/bigtable/data/_sync/_read_rows.py create mode 100644 google/cloud/bigtable/data/_sync/client.py create mode 100644 google/cloud/bigtable/data/_sync/mutations_batcher.py create mode 100644 google/cloud/bigtable/data/_sync/sync_gen.yaml diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index 5229f8021..fd44fe86c 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -20,6 +20,11 @@ from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync +from google.cloud.bigtable.data._sync.client import BigtableDataClient +from google.cloud.bigtable.data._sync.client import Table + +from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.read_rows_query import RowRange from google.cloud.bigtable.data.row import Row @@ -48,12 +53,15 @@ __version__: str = package_version.__version__ __all__ = ( + "BigtableDataClient", + "Table", + "MutationsBatcher", "BigtableDataClientAsync", "TableAsync", + "MutationsBatcherAsync", "RowKeySamples", "ReadRowsQuery", "RowRange", - "MutationsBatcherAsync", "Mutation", "RowMutationEntry", "SetCell", diff --git a/google/cloud/bigtable/data/_sync/_autogen.py 
b/google/cloud/bigtable/data/_sync/_autogen.py new file mode 100644 index 000000000..58abd4cd9 --- /dev/null +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -0,0 +1,1765 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by sync_surface_generator.py. Do not edit. + + +from __future__ import annotations +from abc import ABC +from collections import deque +from functools import partial +from typing import Any +from typing import Generator +from typing import Iterable +from typing import Optional +from typing import Sequence +from typing import Set +from typing import cast +import asyncio +import atexit +import functools +import grpc +import os +import random +import time +import warnings + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core.exceptions import Aborted +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import ServiceUnavailable +from google.api_core.retry import exponential_sleep_generator +from google.cloud.bigtable._mutate_rows import _EntryWithProto +from google.cloud.bigtable._mutate_rows import _MutateRowsOperationAsync +from google.cloud.bigtable._read_rows import _ReadRowsOperationAsync +from google.cloud.bigtable._read_rows import _ResetRow +from google.cloud.bigtable.client import BigtableDataClientAsync +from google.cloud.bigtable.client import TableAsync +from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT +from google.cloud.bigtable.data._async._mutate_rows import ( + _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, +) +from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync +from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync +from google.cloud.bigtable.data._async.client import TableAsync +from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync +from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE +from google.cloud.bigtable.data._helpers import RowKeySamples +from google.cloud.bigtable.data._helpers import ShardedQuery +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _WarmedInstanceKey +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data._helpers import _get_retryable_errors +from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import _make_metadata +from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data.exceptions import FailedMutationEntryError +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup +from google.cloud.bigtable.data.exceptions import _RowSetComplete +from 
google.cloud.bigtable.data.mutations import Mutation +from google.cloud.bigtable.data.mutations import RowMutationEntry +from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.row import Cell +from google.cloud.bigtable.data.row import Row +from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter +from google.cloud.bigtable.data.row_filters import RowFilter +from google.cloud.bigtable.data.row_filters import RowFilterChain +from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter +from google.cloud.bigtable.mutations_batcher import MutationsBatcherAsync +from google.cloud.bigtable.mutations_batcher import _FlowControlAsync +from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient +from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO +from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel, +) +from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB +from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB +from google.cloud.bigtable_v2.types import RowRange as RowRangePB +from google.cloud.bigtable_v2.types import RowSet as RowSetPB +from google.cloud.client import ClientWithProject +from google.cloud.environment_vars import BIGTABLE_EMULATOR +import google.auth._default +import google.auth.credentials +import google.cloud.bigtable.data.exceptions +import google.cloud.bigtable_v2.types.bigtable + + +class _ReadRowsOperation_SyncGen(ABC): + """ + ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream + into a stream of Row objects. + + ReadRowsOperation.merge_row_response_stream takes in a stream of ReadRowsResponse + and turns them into a stream of Row objects using an internal + StateMachine. + + ReadRowsOperation(request, client) handles row merging logic end-to-end, including + performing retries on stream errors. 
+ """ + + __slots__ = ( + "attempt_timeout_gen", + "operation_timeout", + "request", + "table", + "_predicate", + "_metadata", + "_last_yielded_row_key", + "_remaining_count", + ) + + def __init__( + self, + query: ReadRowsQuery, + table: "TableAsync", + operation_timeout: float, + attempt_timeout: float, + retryable_exceptions: Sequence[type[Exception]] = (), + ): + self.attempt_timeout_gen = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + self.operation_timeout = operation_timeout + if isinstance(query, dict): + self.request = ReadRowsRequestPB( + **query, + table_name=table.table_name, + app_profile_id=table.app_profile_id, + ) + else: + self.request = query._to_pb(table) + self.table = table + self._predicate = retries.if_exception_type(*retryable_exceptions) + self._metadata = _make_metadata(table.table_name, table.app_profile_id) + self._last_yielded_row_key: bytes | None = None + self._remaining_count: int | None = self.request.rows_limit or None + + def start_operation(self) -> Generator[Row, None, "Any"]: + """Start the read_rows operation, retrying on retryable errors.""" + return retries.retry_target_stream_async( + self._read_rows_attempt, + self._predicate, + exponential_sleep_generator(0.01, 60, multiplier=2), + self.operation_timeout, + exception_factory=_retry_exception_factory, + ) + + def _read_rows_attempt(self) -> Generator[Row, None, "Any"]: + """ + Attempt a single read_rows rpc call. + This function is intended to be wrapped by retry logic, + which will call this function until it succeeds or + a non-retryable error is raised. + """ + if self._last_yielded_row_key is not None: + try: + self.request.rows = self._revise_request_rowset( + row_set=self.request.rows, + last_seen_row_key=self._last_yielded_row_key, + ) + except _RowSetComplete: + return self.merge_rows(None) + if self._remaining_count is not None: + self.request.rows_limit = self._remaining_count + if self._remaining_count == 0: + return self.merge_rows(None) + gapic_stream = self.table.client._gapic_client.read_rows( + self.request, + timeout=next(self.attempt_timeout_gen), + metadata=self._metadata, + retry=None, + ) + chunked_stream = self.chunk_stream(gapic_stream) + return self.merge_rows(chunked_stream) + + def chunk_stream( + self, stream: Iterable[ReadRowsResponsePB] + ) -> Generator[ReadRowsResponsePB.CellChunk, None, "Any"]: + """process chunks out of raw read_rows stream""" + for resp in stream: + resp = resp._pb + if resp.last_scanned_row_key: + if ( + self._last_yielded_row_key is not None + and resp.last_scanned_row_key <= self._last_yielded_row_key + ): + raise InvalidChunk("last scanned out of order") + self._last_yielded_row_key = resp.last_scanned_row_key + current_key = None + for c in resp.chunks: + if current_key is None: + current_key = c.row_key + if current_key is None: + raise InvalidChunk("first chunk is missing a row key") + elif ( + self._last_yielded_row_key + and current_key <= self._last_yielded_row_key + ): + raise InvalidChunk("row keys should be strictly increasing") + yield c + if c.reset_row: + current_key = None + elif c.commit_row: + self._last_yielded_row_key = current_key + if self._remaining_count is not None: + self._remaining_count -= 1 + if self._remaining_count < 0: + raise InvalidChunk("emit count exceeds row limit") + current_key = None + + @staticmethod + def merge_rows(chunks: Generator[ReadRowsResponsePB.CellChunk, None, "Any"] | None): + """Merge chunks into rows""" + if chunks is None: + return + it = chunks.__iter__() + while True: 
+ try: + c = it.__next__() + except StopIteration: + return + row_key = c.row_key + if not row_key: + raise InvalidChunk("first row chunk is missing key") + cells = [] + family: str | None = None + qualifier: bytes | None = None + try: + while True: + if c.reset_row: + raise _ResetRow(c) + k = c.row_key + f = c.family_name.value + q = c.qualifier.value if c.HasField("qualifier") else None + if k and k != row_key: + raise InvalidChunk("unexpected new row key") + if f: + family = f + if q is not None: + qualifier = q + else: + raise InvalidChunk("new family without qualifier") + elif family is None: + raise InvalidChunk("missing family") + elif q is not None: + if family is None: + raise InvalidChunk("new qualifier without family") + qualifier = q + elif qualifier is None: + raise InvalidChunk("missing qualifier") + ts = c.timestamp_micros + labels = c.labels if c.labels else [] + value = c.value + if c.value_size > 0: + buffer = [value] + while c.value_size > 0: + c = it.__next__() + t = c.timestamp_micros + cl = c.labels + k = c.row_key + if ( + c.HasField("family_name") + and c.family_name.value != family + ): + raise InvalidChunk("family changed mid cell") + if ( + c.HasField("qualifier") + and c.qualifier.value != qualifier + ): + raise InvalidChunk("qualifier changed mid cell") + if t and t != ts: + raise InvalidChunk("timestamp changed mid cell") + if cl and cl != labels: + raise InvalidChunk("labels changed mid cell") + if k and k != row_key: + raise InvalidChunk("row key changed mid cell") + if c.reset_row: + raise _ResetRow(c) + buffer.append(c.value) + value = b"".join(buffer) + cells.append( + Cell(value, row_key, family, qualifier, ts, list(labels)) + ) + if c.commit_row: + yield Row(row_key, cells) + break + c = it.__next__() + except _ResetRow as e: + c = e.chunk + if ( + c.row_key + or c.HasField("family_name") + or c.HasField("qualifier") + or c.timestamp_micros + or c.labels + or c.value + ): + raise InvalidChunk("reset row with data") + continue + except StopIteration: + raise InvalidChunk("premature end of stream") + + @staticmethod + def _revise_request_rowset(row_set: RowSetPB, last_seen_row_key: bytes) -> RowSetPB: + """ + Revise the rows in the request to avoid ones we've already processed. + + Args: + - row_set: the row set from the request + - last_seen_row_key: the last row key encountered + Raises: + - _RowSetComplete: if there are no rows left to process after the revision + """ + if row_set is None or (not row_set.row_ranges and row_set.row_keys is not None): + last_seen = last_seen_row_key + return RowSetPB(row_ranges=[RowRangePB(start_key_open=last_seen)]) + adjusted_keys: list[bytes] = [ + k for k in row_set.row_keys if k > last_seen_row_key + ] + adjusted_ranges: list[RowRangePB] = [] + for row_range in row_set.row_ranges: + end_key = row_range.end_key_closed or row_range.end_key_open or None + if end_key is None or end_key > last_seen_row_key: + new_range = RowRangePB(row_range) + start_key = row_range.start_key_closed or row_range.start_key_open + if start_key is None or start_key <= last_seen_row_key: + new_range.start_key_open = last_seen_row_key + adjusted_ranges.append(new_range) + if len(adjusted_keys) == 0 and len(adjusted_ranges) == 0: + raise _RowSetComplete() + return RowSetPB(row_keys=adjusted_keys, row_ranges=adjusted_ranges) + + +class _MutateRowsOperation_SyncGen(ABC): + """ + MutateRowsOperation manages the logic of sending a set of row mutations, + and retrying on failed entries. 
It manages this using the _run_attempt + function, which attempts to mutate all outstanding entries, and raises + _MutateRowsIncomplete if any retryable errors are encountered. + + Errors are exposed as a MutationsExceptionGroup, which contains a list of + exceptions organized by the related failed mutation entries. + """ + + def __init__( + self, + gapic_client: "BigtableAsyncClient", + table: "TableAsync", + mutation_entries: list["RowMutationEntry"], + operation_timeout: float, + attempt_timeout: float | None, + retryable_exceptions: Sequence[type[Exception]] = (), + ): + """ + Args: + - gapic_client: the client to use for the mutate_rows call + - table: the table associated with the request + - mutation_entries: a list of RowMutationEntry objects to send to the server + - operation_timeout: the timeout to use for the entire operation, in seconds. + - attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds. + If not specified, the request will run until operation_timeout is reached. + """ + total_mutations = sum((len(entry.mutations) for entry in mutation_entries)) + if total_mutations > _MUTATE_ROWS_REQUEST_MUTATION_LIMIT: + raise ValueError( + f"mutate_rows requests can contain at most {_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across all entries. Found {total_mutations}." + ) + metadata = _make_metadata(table.table_name, table.app_profile_id) + self._gapic_fn = functools.partial( + gapic_client.mutate_rows, + table_name=table.table_name, + app_profile_id=table.app_profile_id, + metadata=metadata, + retry=None, + ) + self.is_retryable = retries.if_exception_type( + *retryable_exceptions, bt_exceptions._MutateRowsIncomplete + ) + sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + self._operation = retries.retry_target_async( + self._run_attempt, + self.is_retryable, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) + self.timeout_generator = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries] + self.remaining_indices = list(range(len(self.mutations))) + self.errors: dict[int, list[Exception]] = {} + + def start(self): + """ + Start the operation, and run until completion + + Raises: + - MutationsExceptionGroup: if any mutations failed + """ + try: + self._operation + except Exception as exc: + incomplete_indices = self.remaining_indices.copy() + for idx in incomplete_indices: + self._handle_entry_error(idx, exc) + finally: + all_errors: list[Exception] = [] + for idx, exc_list in self.errors.items(): + if len(exc_list) == 0: + raise core_exceptions.ClientError( + f"Mutation {idx} failed with no associated errors" + ) + elif len(exc_list) == 1: + cause_exc = exc_list[0] + else: + cause_exc = bt_exceptions.RetryExceptionGroup(exc_list) + entry = self.mutations[idx].entry + all_errors.append( + bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc) + ) + if all_errors: + raise bt_exceptions.MutationsExceptionGroup( + all_errors, len(self.mutations) + ) + + def _run_attempt(self): + """ + Run a single attempt of the mutate_rows rpc. 
+ + Raises: + - _MutateRowsIncomplete: if there are failed mutations eligible for + retry after the attempt is complete + - GoogleAPICallError: if the gapic rpc fails + """ + request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] + active_request_indices = { + req_idx: orig_idx + for (req_idx, orig_idx) in enumerate(self.remaining_indices) + } + self.remaining_indices = [] + if not request_entries: + return + try: + result_generator = self._gapic_fn( + timeout=next(self.timeout_generator), + entries=request_entries, + retry=None, + ) + for result_list in result_generator: + for result in result_list.entries: + orig_idx = active_request_indices[result.index] + entry_error = core_exceptions.from_grpc_status( + result.status.code, + result.status.message, + details=result.status.details, + ) + if result.status.code != 0: + self._handle_entry_error(orig_idx, entry_error) + elif orig_idx in self.errors: + del self.errors[orig_idx] + del active_request_indices[result.index] + except Exception as exc: + for idx in active_request_indices.values(): + self._handle_entry_error(idx, exc) + raise + if self.remaining_indices: + raise bt_exceptions._MutateRowsIncomplete + + def _handle_entry_error(self, idx: int, exc: Exception): + """ + Add an exception to the list of exceptions for a given mutation index, + and add the index to the list of remaining indices if the exception is + retryable. + + Args: + - idx: the index of the mutation that failed + - exc: the exception to add to the list + """ + entry = self.mutations[idx].entry + self.errors.setdefault(idx, []).append(exc) + if ( + entry.is_idempotent() + and self.is_retryable(exc) + and (idx not in self.remaining_indices) + ): + self.remaining_indices.append(idx) + + +class MutationsBatcher_SyncGen(ABC): + """ + Allows users to send batches using context manager API: + + Runs mutate_row, mutate_rows, and check_and_mutate_row internally, combining + to use as few network requests as required + + Flushes: + - every flush_interval seconds + - after queue reaches flush_count in quantity + - after queue reaches flush_size_bytes in storage size + - when batcher is closed or destroyed + + async with table.mutations_batcher() as batcher: + for i in range(10): + batcher.add(row, mut) + """ + + def __init__( + self, + table: "TableAsync", + *, + flush_interval: float | None = 5, + flush_limit_mutation_count: int | None = 1000, + flush_limit_bytes: int = 20 * _MB_SIZE, + flow_control_max_mutation_count: int = 100000, + flow_control_max_bytes: int = 100 * _MB_SIZE, + batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + ): + """ + Args: + - table: Table to preform rpc calls + - flush_interval: Automatically flush every flush_interval seconds. + If None, no time-based flushing is performed. + - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count + mutations are added across all entries. If None, this limit is ignored. + - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. + - flow_control_max_mutation_count: Maximum number of inflight mutations. + - flow_control_max_bytes: Maximum number of inflight bytes. + - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. + If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_operation_timeout. 
+ - batch_attempt_timeout: timeout for each individual request, in seconds. + If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to batch_operation_timeout. + - batch_retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_mutate_rows_retryable_errors. + """ + (self._operation_timeout, self._attempt_timeout) = _get_timeouts( + batch_operation_timeout, batch_attempt_timeout, table + ) + self._retryable_errors: list[type[Exception]] = _get_retryable_errors( + batch_retryable_errors, table + ) + self.closed: bool = False + self._table = table + self._staged_entries: list[RowMutationEntry] = [] + (self._staged_count, self._staged_bytes) = (0, 0) + self._flow_control = _FlowControlAsync( + flow_control_max_mutation_count, flow_control_max_bytes + ) + self._flush_limit_bytes = flush_limit_bytes + self._flush_limit_count = ( + flush_limit_mutation_count + if flush_limit_mutation_count is not None + else float("inf") + ) + self._flush_timer = self._start_flush_timer(flush_interval) + self._flush_jobs: set[concurrent.futures.Future[None]] = set() + self._entries_processed_since_last_raise: int = 0 + self._exceptions_since_last_raise: int = 0 + self._exception_list_limit: int = 10 + self._oldest_exceptions: list[Exception] = [] + self._newest_exceptions: deque[Exception] = deque( + maxlen=self._exception_list_limit + ) + atexit.register(self._on_exit) + + def _start_flush_timer( + self, interval: float | None + ) -> concurrent.futures.Future[None]: + """Implementation purposely removed in sync mode""" + + def append(self, mutation_entry: RowMutationEntry): + """ + Add a new set of mutations to the internal queue + + TODO: return a future to track completion of this entry + + Args: + - mutation_entry: new entry to add to flush queue + Raises: + - RuntimeError if batcher is closed + - ValueError if an invalid mutation type is added + """ + if self.closed: + raise RuntimeError("Cannot append to closed MutationsBatcher") + if isinstance(mutation_entry, Mutation): + raise ValueError( + f"invalid mutation type: {type(mutation_entry).__name__}. 
Only RowMutationEntry objects are supported by batcher" + ) + self._staged_entries.append(mutation_entry) + self._staged_count += len(mutation_entry.mutations) + self._staged_bytes += mutation_entry.size() + if ( + self._staged_count >= self._flush_limit_count + or self._staged_bytes >= self._flush_limit_bytes + ): + self._schedule_flush() + time.sleep(0) + + def _schedule_flush(self) -> concurrent.futures.Future[None] | None: + """Update the flush task to include the latest staged entries""" + if self._staged_entries: + (entries, self._staged_entries) = (self._staged_entries, []) + (self._staged_count, self._staged_bytes) = (0, 0) + new_task = self._create_bg_task(self._flush_internal, entries) + new_task.add_done_callback(self._flush_jobs.remove) + self._flush_jobs.add(new_task) + return new_task + return None + + def _flush_internal(self, new_entries: list[RowMutationEntry]): + """ + Flushes a set of mutations to the server, and updates internal state + + Args: + - new_entries: list of RowMutationEntry objects to flush + """ + in_process_requests: list[ + concurrent.futures.Future[list[FailedMutationEntryError]] + ] = [] + for batch in self._flow_control.add_to_flow(new_entries): + batch_task = self._create_bg_task(self._execute_mutate_rows, batch) + in_process_requests.append(batch_task) + found_exceptions = self._wait_for_batch_results(*in_process_requests) + self._entries_processed_since_last_raise += len(new_entries) + self._add_exceptions(found_exceptions) + + def _execute_mutate_rows( + self, batch: list[RowMutationEntry] + ) -> list[FailedMutationEntryError]: + """ + Helper to execute mutation operation on a batch + + Args: + - batch: list of RowMutationEntry objects to send to server + - timeout: timeout in seconds. Used as operation_timeout and attempt_timeout. + If not given, will use table defaults + Returns: + - list of FailedMutationEntryError objects for mutations that failed. + FailedMutationEntryError objects will not contain index information + """ + try: + operation = _MutateRowsOperationAsync( + self._table.client._gapic_client, + self._table, + batch, + operation_timeout=self._operation_timeout, + attempt_timeout=self._attempt_timeout, + retryable_exceptions=self._retryable_errors, + ) + operation.start() + except MutationsExceptionGroup as e: + for subexc in e.exceptions: + subexc.index = None + return list(e.exceptions) + finally: + self._flow_control.remove_from_flow(batch) + return [] + + def _add_exceptions(self, excs: list[Exception]): + """ + Add new list of exceptions to internal store. To avoid unbounded memory, + the batcher will store the first and last _exception_list_limit exceptions, + and discard any in between. 
+ """ + self._exceptions_since_last_raise += len(excs) + if excs and len(self._oldest_exceptions) < self._exception_list_limit: + addition_count = self._exception_list_limit - len(self._oldest_exceptions) + self._oldest_exceptions.extend(excs[:addition_count]) + excs = excs[addition_count:] + if excs: + self._newest_exceptions.extend(excs[-self._exception_list_limit :]) + + def _raise_exceptions(self): + """ + Raise any unreported exceptions from background flush operations + + Raises: + - MutationsExceptionGroup with all unreported exceptions + """ + if self._oldest_exceptions or self._newest_exceptions: + (oldest, self._oldest_exceptions) = (self._oldest_exceptions, []) + newest = list(self._newest_exceptions) + self._newest_exceptions.clear() + (entry_count, self._entries_processed_since_last_raise) = ( + self._entries_processed_since_last_raise, + 0, + ) + (exc_count, self._exceptions_since_last_raise) = ( + self._exceptions_since_last_raise, + 0, + ) + raise MutationsExceptionGroup.from_truncated_lists( + first_list=oldest, + last_list=newest, + total_excs=exc_count, + entry_count=entry_count, + ) + + def __enter__(self): + """For context manager API""" + return self + + def __exit__(self, exc_type, exc, tb): + """For context manager API""" + self.close() + + def close(self): + """Implementation purposely removed in sync mode""" + + def _on_exit(self): + """Called when program is exited. Raises warning if unflushed mutations remain""" + if not self.closed and self._staged_entries: + warnings.warn( + f"MutationsBatcher for table {self._table.table_name} was not closed. {len(self._staged_entries)} Unflushed mutations will not be sent to the server." + ) + + @staticmethod + def _create_bg_task(func, *args, **kwargs) -> concurrent.futures.Future[Any]: + """Implementation purposely removed in sync mode""" + + @staticmethod + def _wait_for_batch_results( + *tasks: concurrent.futures.Future[list[FailedMutationEntryError]] + | concurrent.futures.Future[None], + ) -> list[Exception]: + """Implementation purposely removed in sync mode""" + + +class _FlowControl_SyncGen(ABC): + """ + Manages flow control for batched mutations. Mutations are registered against + the FlowControl object before being sent, which will block if size or count + limits have reached capacity. As mutations completed, they are removed from + the FlowControl object, which will notify any blocked requests that there + is additional capacity. + + Flow limits are not hard limits. If a single mutation exceeds the configured + limits, it will be allowed as a single batch when the capacity is available. + """ + + def __init__(self, max_mutation_count: int, max_mutation_bytes: int): + """ + Args: + - max_mutation_count: maximum number of mutations to send in a single rpc. + This corresponds to individual mutations in a single RowMutationEntry. + - max_mutation_bytes: maximum number of bytes to send in a single rpc. 
+ """ + self._max_mutation_count = max_mutation_count + self._max_mutation_bytes = max_mutation_bytes + if self._max_mutation_count < 1: + raise ValueError("max_mutation_count must be greater than 0") + if self._max_mutation_bytes < 1: + raise ValueError("max_mutation_bytes must be greater than 0") + self._capacity_condition = threading.Condition() + self._in_flight_mutation_count = 0 + self._in_flight_mutation_bytes = 0 + + def _has_capacity(self, additional_count: int, additional_size: int) -> bool: + """ + Checks if there is capacity to send a new entry with the given size and count + + FlowControl limits are not hard limits. If a single mutation exceeds + the configured flow limits, it will be sent in a single batch when + previous batches have completed. + + Args: + - additional_count: number of mutations in the pending entry + - additional_size: size of the pending entry + Returns: + - True if there is capacity to send the pending entry, False otherwise + """ + acceptable_size = max(self._max_mutation_bytes, additional_size) + acceptable_count = max(self._max_mutation_count, additional_count) + new_size = self._in_flight_mutation_bytes + additional_size + new_count = self._in_flight_mutation_count + additional_count + return new_size <= acceptable_size and new_count <= acceptable_count + + def remove_from_flow( + self, mutations: RowMutationEntry | list[RowMutationEntry] + ) -> None: + """ + Removes mutations from flow control. This method should be called once + for each mutation that was sent to add_to_flow, after the corresponding + operation is complete. + + Args: + - mutations: mutation or list of mutations to remove from flow control + """ + if not isinstance(mutations, list): + mutations = [mutations] + total_count = sum((len(entry.mutations) for entry in mutations)) + total_size = sum((entry.size() for entry in mutations)) + self._in_flight_mutation_count -= total_count + self._in_flight_mutation_bytes -= total_size + with self._capacity_condition: + self._capacity_condition.notify_all() + + def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): + """ + Generator function that registers mutations with flow control. As mutations + are accepted into the flow control, they are yielded back to the caller, + to be sent in a batch. If the flow control is at capacity, the generator + will block until there is capacity available. + + Args: + - mutations: list mutations to break up into batches + Yields: + - list of mutations that have reserved space in the flow control. + Each batch contains at least one mutation. 
+        """
+        if not isinstance(mutations, list):
+            mutations = [mutations]
+        start_idx = 0
+        end_idx = 0
+        while end_idx < len(mutations):
+            start_idx = end_idx
+            batch_mutation_count = 0
+            with self._capacity_condition:
+                while end_idx < len(mutations):
+                    next_entry = mutations[end_idx]
+                    next_size = next_entry.size()
+                    next_count = len(next_entry.mutations)
+                    if (
+                        self._has_capacity(next_count, next_size)
+                        and batch_mutation_count + next_count
+                        <= _MUTATE_ROWS_REQUEST_MUTATION_LIMIT
+                    ):
+                        end_idx += 1
+                        batch_mutation_count += next_count
+                        self._in_flight_mutation_bytes += next_size
+                        self._in_flight_mutation_count += next_count
+                    elif start_idx != end_idx:
+                        break
+                    else:
+                        self._capacity_condition.wait_for(
+                            lambda: self._has_capacity(next_count, next_size)
+                        )
+            yield mutations[start_idx:end_idx]
+
+
+class BigtableDataClient_SyncGen(ClientWithProject, ABC):
+    def __init__(
+        self,
+        *,
+        project: str | None = None,
+        pool_size: int = 3,
+        credentials: google.auth.credentials.Credentials | None = None,
+        client_options: dict[str, Any]
+        | "google.api_core.client_options.ClientOptions"
+        | None = None,
+    ):
+        """
+        Create a client instance for the Bigtable Data API
+
+        Client should be created within an async context (running event loop)
+
+        Warning: BigtableDataClientAsync is currently in preview, and is not
+        yet recommended for production use.
+
+        Args:
+            project: the project which the client acts on behalf of.
+                If not passed, falls back to the default inferred
+                from the environment.
+            pool_size: The number of grpc channels to maintain
+                in the internal channel pool.
+            credentials:
+                The OAuth2 Credentials to use for this
+                client. If not passed (and if no ``_http`` object is
+                passed), falls back to the default inferred from the
+                environment.
+            client_options (Optional[Union[dict, google.api_core.client_options.ClientOptions]]):
+                Client options used to set user options
+                on the client. API Endpoint should be set through client_options.
+ Raises: + - RuntimeError if called outside of an async context (no running event loop) + - ValueError if pool_size is less than 1 + """ + transport_str = f"pooled_grpc_asyncio_{pool_size}" + transport = BigtableGrpcTransport.with_fixed_size(pool_size) + BigtableClientMeta._transport_registry[transport_str] = transport + client_info = DEFAULT_CLIENT_INFO + client_info.client_library_version = self._client_version() + if type(client_options) is dict: + client_options = client_options_lib.from_dict(client_options) + client_options = cast( + Optional[client_options_lib.ClientOptions], client_options + ) + self._emulator_host = os.getenv(BIGTABLE_EMULATOR) + if self._emulator_host is not None: + if credentials is None: + credentials = google.auth.credentials.AnonymousCredentials() + if project is None: + project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT + ClientWithProject.__init__( + self, + credentials=credentials, + project=project, + client_options=client_options, + ) + self._gapic_client = BigtableClient( + transport=transport_str, + credentials=credentials, + client_options=client_options, + client_info=client_info, + ) + self.transport = cast(BigtableGrpcTransport, self._gapic_client.transport) + self._active_instances: Set[_WarmedInstanceKey] = set() + self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} + self._channel_init_time = time.monotonic() + self._channel_refresh_tasks: list[asyncio.Task[None]] = [] + if self._emulator_host is not None: + warnings.warn( + "Connecting to Bigtable emulator at {}".format(self._emulator_host), + RuntimeWarning, + stacklevel=2, + ) + self.transport._grpc_channel = PooledChannel( + pool_size=pool_size, host=self._emulator_host, insecure=True + ) + self.transport._stubs = {} + self.transport._prep_wrapped_messages(client_info) + else: + try: + self._start_background_channel_refresh() + except RuntimeError: + warnings.warn( + f"{self.__class__.__name__} should be started in an asyncio event loop. Channel refresh will not be started", + RuntimeWarning, + stacklevel=2, + ) + + @staticmethod + def _client_version() -> str: + """Helper function to return the client version string for this client""" + return f"{google.cloud.bigtable.__version__}-data-async" + + def _start_background_channel_refresh(self) -> None: + """Implementation purposely removed in sync mode""" + + def close(self, timeout: float = 2.0): + """Implementation purposely removed in sync mode""" + + def _ping_and_warm_instances( + self, channel: grpc.aio.Channel, instance_key: _WarmedInstanceKey | None = None + ) -> list[BaseException | None]: + """Implementation purposely removed in sync mode""" + + def _manage_channel( + self, + channel_idx: int, + refresh_interval_min: float = 60 * 35, + refresh_interval_max: float = 60 * 45, + grace_period: float = 60 * 10, + ) -> None: + """ + Background coroutine that periodically refreshes and warms a grpc channel + + The backend will automatically close channels after 60 minutes, so + `refresh_interval` + `grace_period` should be < 60 minutes + + Runs continuously until the client is closed + + Args: + channel_idx: index of the channel in the transport's channel pool + refresh_interval_min: minimum interval before initiating refresh + process in seconds. Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + refresh_interval_max: maximum interval before initiating refresh + process in seconds. 
Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + grace_period: time to allow previous channel to serve existing + requests before closing, in seconds + """ + first_refresh = self._channel_init_time + random.uniform( + refresh_interval_min, refresh_interval_max + ) + next_sleep = max(first_refresh - time.monotonic(), 0) + if next_sleep > 0: + channel = self.transport.channels[channel_idx] + self._ping_and_warm_instances(channel) + while True: + time.sleep(next_sleep) + new_channel = self.transport.grpc_channel._create_channel() + self._ping_and_warm_instances(new_channel) + start_timestamp = time.time() + self.transport.replace_channel( + channel_idx, grace=grace_period, swap_sleep=10, new_channel=new_channel + ) + next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) + next_sleep = next_refresh - (time.time() - start_timestamp) + + def _register_instance(self, instance_id: str, owner: TableAsync) -> None: + """ + Registers an instance with the client, and warms the channel pool + for the instance + The client will periodically refresh grpc channel pool used to make + requests, and new channels will be warmed for each registered instance + Channels will not be refreshed unless at least one instance is registered + + Args: + - instance_id: id of the instance to register. + - owner: table that owns the instance. Owners will be tracked in + _instance_owners, and instances will only be unregistered when all + owners call _remove_instance_registration + """ + instance_name = self._gapic_client.instance_path(self.project, instance_id) + instance_key = _WarmedInstanceKey( + instance_name, owner.table_name, owner.app_profile_id + ) + self._instance_owners.setdefault(instance_key, set()).add(id(owner)) + if instance_name not in self._active_instances: + self._active_instances.add(instance_key) + if self._channel_refresh_tasks: + for channel in self.transport.channels: + self._ping_and_warm_instances(channel, instance_key) + else: + self._start_background_channel_refresh() + + def _remove_instance_registration( + self, instance_id: str, owner: TableAsync + ) -> bool: + """ + Removes an instance from the client's registered instances, to prevent + warming new channels for the instance + + If instance_id is not registered, or is still in use by other tables, returns False + + Args: + - instance_id: id of the instance to remove + - owner: table that owns the instance. Owners will be tracked in + _instance_owners, and instances will only be unregistered when all + owners call _remove_instance_registration + Returns: + - True if instance was removed + """ + instance_name = self._gapic_client.instance_path(self.project, instance_id) + instance_key = _WarmedInstanceKey( + instance_name, owner.table_name, owner.app_profile_id + ) + owner_list = self._instance_owners.get(instance_key, set()) + try: + owner_list.remove(id(owner)) + if len(owner_list) == 0: + self._active_instances.remove(instance_key) + return True + except KeyError: + return False + + def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: + """ + Returns a table instance for making data API requests. All arguments are passed + directly to the TableAsync constructor. + + Args: + instance_id: The Bigtable instance ID to associate with this client. + instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. 
table_id is combined with the + instance_id and the client's project to fully specify the table + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to 20 seconds + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to 60 seconds + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + """ + return TableAsync(self, instance_id, table_id, *args, **kwargs) + + def __enter__(self): + self._start_background_channel_refresh() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + self._gapic_client.__exit__(exc_type, exc_val, exc_tb) + + +class Table_SyncGen(ABC): + """ + Main Data API surface + + Table object maintains table_id, and app_profile_id context, and passes them with + each call + """ + + def __init__( + self, + client: BigtableDataClientAsync, + instance_id: str, + table_id: str, + app_profile_id: str | None = None, + *, + default_read_rows_operation_timeout: float = 600, + default_read_rows_attempt_timeout: float | None = 20, + default_mutate_rows_operation_timeout: float = 600, + default_mutate_rows_attempt_timeout: float | None = 60, + default_operation_timeout: float = 60, + default_attempt_timeout: float | None = 20, + default_read_rows_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + Aborted, + ), + default_mutate_rows_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + ), + default_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + ), + ): + """Implementation purposely removed in sync mode""" + + def read_rows_stream( + self, + query: ReadRowsQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> Iterable[Row]: + """ + Read a set of rows from the table, based on the specified query. + Returns an iterator to asynchronously stream back row data. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. 
+
+        Warning: BigtableDataClientAsync is currently in preview, and is not
+        yet recommended for production use.
+
+        Args:
+            - query: contains details about which rows to return
+            - operation_timeout: the time budget for the entire operation, in seconds.
+                Failed requests will be retried within the budget.
+                Defaults to the Table's default_read_rows_operation_timeout
+            - attempt_timeout: the time budget for an individual network request, in seconds.
+                If it takes longer than this time to complete, the request will be cancelled with
+                a DeadlineExceeded exception, and a retry will be attempted.
+                Defaults to the Table's default_read_rows_attempt_timeout.
+                If None, defaults to operation_timeout.
+            - retryable_errors: a list of errors that will be retried if encountered.
+                Defaults to the Table's default_read_rows_retryable_errors
+        Returns:
+            - an asynchronous iterator that yields rows returned by the query
+        Raises:
+            - DeadlineExceeded: raised after operation timeout
+                will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
+                from any retries that failed
+            - GoogleAPIError: raised if the request encounters an unrecoverable error
+        """
+        (operation_timeout, attempt_timeout) = _get_timeouts(
+            operation_timeout, attempt_timeout, self
+        )
+        retryable_excs = _get_retryable_errors(retryable_errors, self)
+        row_merger = _ReadRowsOperationAsync(
+            query,
+            self,
+            operation_timeout=operation_timeout,
+            attempt_timeout=attempt_timeout,
+            retryable_exceptions=retryable_excs,
+        )
+        return row_merger.start_operation()
+
+    def read_rows(
+        self,
+        query: ReadRowsQuery,
+        *,
+        operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+        attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+        retryable_errors: Sequence[type[Exception]]
+        | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+    ) -> list[Row]:
+        """
+        Read a set of rows from the table, based on the specified query.
+        Returns results as a list of Row objects when the request is complete.
+        For streamed results, use read_rows_stream.
+
+        Failed requests within operation_timeout will be retried based on the
+        retryable_errors list until operation_timeout is reached.
+
+        Warning: BigtableDataClientAsync is currently in preview, and is not
+        yet recommended for production use.
+
+        Args:
+            - query: contains details about which rows to return
+            - operation_timeout: the time budget for the entire operation, in seconds.
+                Failed requests will be retried within the budget.
+                Defaults to the Table's default_read_rows_operation_timeout
+            - attempt_timeout: the time budget for an individual network request, in seconds.
+                If it takes longer than this time to complete, the request will be cancelled with
+                a DeadlineExceeded exception, and a retry will be attempted.
+                Defaults to the Table's default_read_rows_attempt_timeout.
+                If None, defaults to operation_timeout.
+                If None, defaults to the Table's default_read_rows_attempt_timeout,
+                or the operation_timeout if that is also None.
+            - retryable_errors: a list of errors that will be retried if encountered.
+                Defaults to the Table's default_read_rows_retryable_errors.
+ Returns: + - a list of Rows returned by the query + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + row_generator = self.read_rows_stream( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) + return [row for row in row_generator] + + def read_row( + self, + row_key: str | bytes, + *, + row_filter: RowFilter | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> Row | None: + """ + Read a single row from the table, based on the specified key. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - query: contains details about which rows to return + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. + Returns: + - a Row object if the row exists, otherwise None + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + if row_key is None: + raise ValueError("row_key must be string or bytes") + query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1) + results = self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) + if len(results) == 0: + return None + return results[0] + + def read_rows_sharded( + self, + sharded_query: ShardedQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> list[Row]: + """Implementation purposely removed in sync mode""" + + def row_exists( + self, + row_key: str | bytes, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> bool: + """ + Return a boolean indicating whether the specified row exists in the table. + uses the filters: chain(limit cells per row = 1, strip value) + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. 
+
+        Args:
+            - row_key: the key of the row to check
+            - operation_timeout: the time budget for the entire operation, in seconds.
+                Failed requests will be retried within the budget.
+                Defaults to the Table's default_read_rows_operation_timeout
+            - attempt_timeout: the time budget for an individual network request, in seconds.
+                If it takes longer than this time to complete, the request will be cancelled with
+                a DeadlineExceeded exception, and a retry will be attempted.
+                Defaults to the Table's default_read_rows_attempt_timeout.
+                If None, defaults to operation_timeout.
+            - retryable_errors: a list of errors that will be retried if encountered.
+                Defaults to the Table's default_read_rows_retryable_errors.
+        Returns:
+            - a bool indicating whether the row exists
+        Raises:
+            - DeadlineExceeded: raised after operation timeout
+                will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
+                from any retries that failed
+            - GoogleAPIError: raised if the request encounters an unrecoverable error
+        """
+        if row_key is None:
+            raise ValueError("row_key must be string or bytes")
+        strip_filter = StripValueTransformerFilter(flag=True)
+        limit_filter = CellsRowLimitFilter(1)
+        chain_filter = RowFilterChain(filters=[limit_filter, strip_filter])
+        query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter)
+        results = self.read_rows(
+            query,
+            operation_timeout=operation_timeout,
+            attempt_timeout=attempt_timeout,
+            retryable_errors=retryable_errors,
+        )
+        return len(results) > 0
+
+    def sample_row_keys(
+        self,
+        *,
+        operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+        attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+        retryable_errors: Sequence[type[Exception]]
+        | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
+    ) -> RowKeySamples:
+        """
+        Return a set of RowKeySamples that delimit contiguous sections of the table of
+        approximately equal size
+
+        RowKeySamples output can be used with ReadRowsQuery.shard() to create a sharded query that
+        can be parallelized across multiple backend nodes read_rows and read_rows_stream
+        requests will call sample_row_keys internally for this purpose when sharding is enabled
+
+        RowKeySamples is simply a type alias for list[tuple[bytes, int]]; a list of
+        row_keys, along with offset positions in the table
+
+        Warning: BigtableDataClientAsync is currently in preview, and is not
+        yet recommended for production use.
+
+        Args:
+            - operation_timeout: the time budget for the entire operation, in seconds.
+                Failed requests will be retried within the budget.
+                Defaults to the Table's default_operation_timeout
+            - attempt_timeout: the time budget for an individual network request, in seconds.
+                If it takes longer than this time to complete, the request will be cancelled with
+                a DeadlineExceeded exception, and a retry will be attempted.
+                Defaults to the Table's default_attempt_timeout.
+                If None, defaults to operation_timeout.
+            - retryable_errors: a list of errors that will be retried if encountered.
+                Defaults to the Table's default_retryable_errors.
+ Returns: + - a set of RowKeySamples the delimit contiguous sections of the table + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + (operation_timeout, attempt_timeout) = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + attempt_timeout_gen = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + retryable_excs = _get_retryable_errors(retryable_errors, self) + predicate = retries.if_exception_type(*retryable_excs) + sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + metadata = _make_metadata(self.table_name, self.app_profile_id) + + def execute_rpc(): + results = self.client._gapic_client.sample_row_keys( + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=next(attempt_timeout_gen), + metadata=metadata, + retry=None, + ) + return [(s.row_key, s.offset_bytes) for s in results] + + return retries.retry_target_async( + execute_rpc, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) + + def mutations_batcher( + self, + *, + flush_interval: float | None = 5, + flush_limit_mutation_count: int | None = 1000, + flush_limit_bytes: int = 20 * _MB_SIZE, + flow_control_max_mutation_count: int = 100000, + flow_control_max_bytes: int = 100 * _MB_SIZE, + batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + ) -> MutationsBatcherAsync: + """ + Returns a new mutations batcher instance. + + Can be used to iteratively add mutations that are flushed as a group, + to avoid excess network calls + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - flush_interval: Automatically flush every flush_interval seconds. If None, + a table default will be used + - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count + mutations are added across all entries. If None, this limit is ignored. + - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. + - flow_control_max_mutation_count: Maximum number of inflight mutations. + - flow_control_max_bytes: Maximum number of inflight bytes. + - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. + Defaults to the Table's default_mutate_rows_operation_timeout + - batch_attempt_timeout: timeout for each individual request, in seconds. + Defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to batch_operation_timeout. + - batch_retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_mutate_rows_retryable_errors. 
+ Returns: + - a MutationsBatcherAsync context manager that can batch requests + """ + return MutationsBatcherAsync( + self, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_mutation_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=batch_operation_timeout, + batch_attempt_timeout=batch_attempt_timeout, + batch_retryable_errors=batch_retryable_errors, + ) + + def mutate_row( + self, + row_key: str | bytes, + mutations: list[Mutation] | Mutation, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ): + """ + Mutates a row atomically. + + Cells already present in the row are left unchanged unless explicitly changed + by ``mutation``. + + Idempotent operations (i.e, all mutations have an explicit timestamp) will be + retried on server failure. Non-idempotent operations will not. + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - row_key: the row to apply mutations to + - mutations: the set of mutations to apply to the row + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Only idempotent mutations will be retried. Defaults to the Table's + default_retryable_errors. + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing all + GoogleAPIError exceptions from any retries that failed + - GoogleAPIError: raised on non-idempotent operations that cannot be + safely retried. 
+            - ValueError if invalid arguments are provided
+        """
+        (operation_timeout, attempt_timeout) = _get_timeouts(
+            operation_timeout, attempt_timeout, self
+        )
+        if not mutations:
+            raise ValueError("No mutations provided")
+        mutations_list = mutations if isinstance(mutations, list) else [mutations]
+        if all((mutation.is_idempotent() for mutation in mutations_list)):
+            predicate = retries.if_exception_type(
+                *_get_retryable_errors(retryable_errors, self)
+            )
+        else:
+            predicate = retries.if_exception_type()
+        sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60)
+        target = partial(
+            self.client._gapic_client.mutate_row,
+            row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
+            mutations=[mutation._to_pb() for mutation in mutations_list],
+            table_name=self.table_name,
+            app_profile_id=self.app_profile_id,
+            timeout=attempt_timeout,
+            metadata=_make_metadata(self.table_name, self.app_profile_id),
+            retry=None,
+        )
+        return retries.retry_target_async(
+            target,
+            predicate,
+            sleep_generator,
+            operation_timeout,
+            exception_factory=_retry_exception_factory,
+        )
+
+    def bulk_mutate_rows(
+        self,
+        mutation_entries: list[RowMutationEntry],
+        *,
+        operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+        attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+        retryable_errors: Sequence[type[Exception]]
+        | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+    ):
+        """
+        Applies mutations for multiple rows in a single batched request.
+
+        Each individual RowMutationEntry is applied atomically, but separate entries
+        may be applied in arbitrary order (even for entries targeting the same row)
+        In total, the row_mutations can contain at most 100000 individual mutations
+        across all entries
+
+        Idempotent entries (i.e., entries with mutations with explicit timestamps)
+        will be retried on failure. Non-idempotent will not, and will be reported in a
+        raised exception group
+
+        Warning: BigtableDataClientAsync is currently in preview, and is not
+        yet recommended for production use.
+
+        Args:
+            - mutation_entries: the batches of mutations to apply
+                Each entry will be applied atomically, but entries will be applied
+                in arbitrary order
+            - operation_timeout: the time budget for the entire operation, in seconds.
+                Failed requests will be retried within the budget.
+                Defaults to the Table's default_mutate_rows_operation_timeout
+            - attempt_timeout: the time budget for an individual network request, in seconds.
+                If it takes longer than this time to complete, the request will be cancelled with
+                a DeadlineExceeded exception, and a retry will be attempted.
+                Defaults to the Table's default_mutate_rows_attempt_timeout.
+                If None, defaults to operation_timeout.
+            - retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_mutate_rows_retryable_errors + Raises: + - MutationsExceptionGroup if one or more mutations fails + Contains details about any failed entries in .exceptions + - ValueError if invalid arguments are provided + """ + (operation_timeout, attempt_timeout) = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + retryable_excs = _get_retryable_errors(retryable_errors, self) + operation = _MutateRowsOperationAsync( + self.client._gapic_client, + self, + mutation_entries, + operation_timeout, + attempt_timeout, + retryable_exceptions=retryable_excs, + ) + operation.start() + + def check_and_mutate_row( + self, + row_key: str | bytes, + predicate: RowFilter | None, + *, + true_case_mutations: Mutation | list[Mutation] | None = None, + false_case_mutations: Mutation | list[Mutation] | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ) -> bool: + """ + Mutates a row atomically based on the output of a predicate filter + + Non-idempotent operation: will not be retried + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - row_key: the key of the row to mutate + - predicate: the filter to be applied to the contents of the specified row. + Depending on whether or not any results are yielded, + either true_case_mutations or false_case_mutations will be executed. + If None, checks that the row contains any values at all. + - true_case_mutations: + Changes to be atomically applied to the specified row if + predicate yields at least one cell when + applied to row_key. Entries are applied in order, + meaning that earlier mutations can be masked by later + ones. Must contain at least one entry if + false_case_mutations is empty, and at most 100000. + - false_case_mutations: + Changes to be atomically applied to the specified row if + predicate_filter does not yield any cells when + applied to row_key. Entries are applied in order, + meaning that earlier mutations can be masked by later + ones. Must contain at least one entry if + `true_case_mutations is empty, and at most 100000. + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will not be retried. 
Defaults to the Table's default_operation_timeout + Returns: + - bool indicating whether the predicate was true or false + Raises: + - GoogleAPIError exceptions from grpc call + """ + (operation_timeout, _) = _get_timeouts(operation_timeout, None, self) + if true_case_mutations is not None and ( + not isinstance(true_case_mutations, list) + ): + true_case_mutations = [true_case_mutations] + true_case_list = [m._to_pb() for m in true_case_mutations or []] + if false_case_mutations is not None and ( + not isinstance(false_case_mutations, list) + ): + false_case_mutations = [false_case_mutations] + false_case_list = [m._to_pb() for m in false_case_mutations or []] + metadata = _make_metadata(self.table_name, self.app_profile_id) + result = self.client._gapic_client.check_and_mutate_row( + true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) + return result.predicate_matched + + def read_modify_write_row( + self, + row_key: str | bytes, + rules: ReadModifyWriteRule | list[ReadModifyWriteRule], + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ) -> Row: + """ + Reads and modifies a row atomically according to input ReadModifyWriteRules, + and returns the contents of all modified cells + + The new value for the timestamp is the greater of the existing timestamp or + the current server time. + + Non-idempotent operation: will not be retried + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - row_key: the key of the row to apply read/modify/write rules to + - rules: A rule or set of rules to apply to the row. + Rules are applied in order, meaning that earlier rules will affect the + results of later ones. + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will not be retried. + Defaults to the Table's default_operation_timeout. 
+ Returns: + - Row: containing cell data that was modified as part of the + operation + Raises: + - GoogleAPIError exceptions from grpc call + - ValueError if invalid arguments are provided + """ + (operation_timeout, _) = _get_timeouts(operation_timeout, None, self) + if operation_timeout <= 0: + raise ValueError("operation_timeout must be greater than 0") + if rules is not None and (not isinstance(rules, list)): + rules = [rules] + if not rules: + raise ValueError("rules must contain at least one item") + metadata = _make_metadata(self.table_name, self.app_profile_id) + result = self.client._gapic_client.read_modify_write_row( + rules=[rule._to_pb() for rule in rules], + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) + return Row._from_pb(result.row) + + def close(self): + """Called to close the Table instance and release any resources held by it.""" + self._register_instance_task.cancel() + self.client._remove_instance_registration(self.instance_id, self) + + def __enter__(self): + """ + Implement async context manager protocol + + Ensure registration task has time to run, so that + grpc channels will be warmed for the specified instance + """ + self._register_instance_task + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Implement async context manager protocol + + Unregister this instance with the client, so that + grpc channels will no longer be warmed + """ + self.close() diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py new file mode 100644 index 000000000..1841c814b --- /dev/null +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -0,0 +1,22 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + + +from google.cloud.bigtable.data._sync._autogen import _MutateRowsOperation_SyncGen + + +class _MutateRowsOperation(_MutateRowsOperation_SyncGen): + pass diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py new file mode 100644 index 000000000..f43822ba4 --- /dev/null +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -0,0 +1,22 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import annotations + + +from google.cloud.bigtable.data._sync._autogen import _ReadRowsOperation_SyncGen + + +class _ReadRowsOperation(_ReadRowsOperation_SyncGen): + pass diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py new file mode 100644 index 000000000..2d87f0758 --- /dev/null +++ b/google/cloud/bigtable/data/_sync/client.py @@ -0,0 +1,45 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import Any + +import google.auth.credentials + +from google.cloud.bigtable.data._sync._autogen import BigtableDataClient_SyncGen +from google.cloud.bigtable.data._sync._autogen import Table_SyncGen + + +class BigtableDataClient(BigtableDataClient_SyncGen): + def __init__( + self, + *, + project: str | None = None, + credentials: google.auth.credentials.Credentials | None = None, + client_options: dict[str, Any] + | "google.api_core.client_options.ClientOptions" + | None = None, + ): + # remove pool size option in sync client + super().__init__( + project=project, credentials=credentials, client_options=client_options + ) + + +class Table(Table_SyncGen): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # register table with client + self.client._register_instance(self.instance_id, self) diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py new file mode 100644 index 000000000..d7cbc428c --- /dev/null +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -0,0 +1,26 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import annotations + + +from google.cloud.bigtable.data._sync._autogen import _FlowControl_SyncGen +from google.cloud.bigtable.data._sync._autogen import MutationsBatcher_SyncGen + + +class _FlowControl(_FlowControl_SyncGen): + pass + +class MutationsBatcher(MutationsBatcher_SyncGen): + pass diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml new file mode 100644 index 000000000..88c3e5d1d --- /dev/null +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -0,0 +1,3 @@ +async_classes: + - "google.cloud.bigtable._async._read_rows._ReadRowsOperationAsync" + From 70fd73158fccef37f6a74bec28a8f88123fc2fa6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Apr 2024 16:12:56 -0700 Subject: [PATCH 011/360] keep docstrings on pass --- google/cloud/bigtable/data/__init__.py | 12 +-- google/cloud/bigtable/data/_sync/_autogen.py | 108 +++++++++++++++++-- sync_surface_generator.py | 24 +++-- 3 files changed, 117 insertions(+), 27 deletions(-) diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index fd44fe86c..cdb7622b6 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -20,10 +20,10 @@ from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync -from google.cloud.bigtable.data._sync.client import BigtableDataClient -from google.cloud.bigtable.data._sync.client import Table +# from google.cloud.bigtable.data._sync.client import BigtableDataClient +# from google.cloud.bigtable.data._sync.client import Table -from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher +# from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.read_rows_query import RowRange @@ -53,9 +53,9 @@ __version__: str = package_version.__version__ __all__ = ( - "BigtableDataClient", - "Table", - "MutationsBatcher", + # "BigtableDataClient", + # "Table", + # "MutationsBatcher", "BigtableDataClientAsync", "TableAsync", "MutationsBatcherAsync", diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 58abd4cd9..634bc2664 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -586,7 +586,7 @@ def __init__( def _start_flush_timer( self, interval: float | None ) -> concurrent.futures.Future[None]: - """Implementation purposely removed in sync mode""" + raise NotImplementedError("Function not implemented in sync class") def append(self, mutation_entry: RowMutationEntry): """ @@ -725,7 +725,7 @@ def __exit__(self, exc_type, exc, tb): self.close() def close(self): - """Implementation purposely removed in sync mode""" + """Flush queue and clean up resources""" def _on_exit(self): """Called when program is exited. 
Raises warning if unflushed mutations remain""" @@ -736,14 +736,25 @@ def _on_exit(self): @staticmethod def _create_bg_task(func, *args, **kwargs) -> concurrent.futures.Future[Any]: - """Implementation purposely removed in sync mode""" + raise NotImplementedError("Function not implemented in sync class") @staticmethod def _wait_for_batch_results( *tasks: concurrent.futures.Future[list[FailedMutationEntryError]] | concurrent.futures.Future[None], ) -> list[Exception]: - """Implementation purposely removed in sync mode""" + """ + Takes in a list of futures representing _execute_mutate_rows tasks, + waits for them to complete, and returns a list of errors encountered. + + Args: + - *tasks: futures representing _execute_mutate_rows or _flush_internal tasks + Returns: + - list of Exceptions encountered by any of the tasks. Errors are expected + to be FailedMutationEntryError, representing a failed mutation operation. + If a task fails with a different exception, it will be included in the + output list. Successful tasks will not be represented in the output list. + """ class _FlowControl_SyncGen(ABC): @@ -917,7 +928,7 @@ def __init__( project=project, client_options=client_options, ) - self._gapic_client = BigtableClient( + self._gapic_client = BigtableAsyncClient( transport=transport_str, credentials=credentials, client_options=client_options, @@ -955,15 +966,25 @@ def _client_version() -> str: return f"{google.cloud.bigtable.__version__}-data-async" def _start_background_channel_refresh(self) -> None: - """Implementation purposely removed in sync mode""" + raise NotImplementedError("Function not implemented in sync class") def close(self, timeout: float = 2.0): - """Implementation purposely removed in sync mode""" + """Cancel all background tasks""" def _ping_and_warm_instances( self, channel: grpc.aio.Channel, instance_key: _WarmedInstanceKey | None = None ) -> list[BaseException | None]: - """Implementation purposely removed in sync mode""" + """ + Prepares the backend for requests on a channel + + Pings each Bigtable instance registered in `_active_instances` on the client + + Args: + - channel: grpc channel to warm + - instance_key: if provided, only warm the instance associated with the key + Returns: + - sequence of results or exceptions from the ping requests + """ def _manage_channel( self, @@ -1147,7 +1168,43 @@ def __init__( ServiceUnavailable, ), ): - """Implementation purposely removed in sync mode""" + """ + Initialize a Table instance + + Must be created within an async context (running event loop) + + Args: + instance_id: The Bigtable instance ID to associate with this client. + instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. table_id is combined with the + instance_id and the client's project to fully specify the table + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to 20 seconds + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. 
If not set, defaults to 60 seconds + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to 60 seconds + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + Raises: + - RuntimeError if called outside of an async context (no running event loop) + """ def read_rows_stream( self, @@ -1312,7 +1369,38 @@ def read_rows_sharded( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: - """Implementation purposely removed in sync mode""" + """ + Runs a sharded query in parallel, then return the results in a single list. + Results will be returned in the order of the input queries. + + This function is intended to be run on the results on a query.shard() call: + + ``` + table_shard_keys = await table.sample_row_keys() + query = ReadRowsQuery(...) + shard_queries = query.shard(table_shard_keys) + results = await table.read_rows_sharded(shard_queries) + ``` + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - sharded_query: a sharded query to execute + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. 
+ Raises: + - ShardedReadRowsExceptionGroup: if any of the queries failed + - ValueError: if the query_list is empty + """ def row_exists( self, diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 77da234f6..abd34b3b0 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -90,10 +90,10 @@ def visit_AsyncFunctionDef(self, node): if node.name in self.drop_methods: return None elif node.name in self.pass_methods: - # replace with pass - node.body = [ast.Expr(value=ast.Str(s="Implementation purposely removed in sync mode"))] + # keep only docstring in pass mode + node.body = [ast.Expr(value=ast.Str(s=docstring))] elif node.name in self.error_methods: - self._create_error_node(node, "Function marked as unsupported in sync mode") + self._create_error_node(node, "Function not implemented in sync class") else: # check if the function contains non-replaced usage of asyncio func_ast = ast.parse(ast.unparse(node)) @@ -412,12 +412,12 @@ def transform_from_config(config_dict: dict): "AsyncIterator": "Iterator", "AsyncGenerator": "Generator", "StopAsyncIteration": "StopIteration", - "BigtableAsyncClient": "BigtableClient", "AsyncRetry": "Retry", - "PooledBigtableGrpcAsyncIOTransport": "BigtableGrpcTransport", "Awaitable": None, "pytest_asyncio": "pytest", "AsyncMock": "mock.Mock", + "PooledBigtableGrpcAsyncIOTransport": "BigtableGrpcTransport", + # "BigtableAsyncClient": "BigtableClient", # "_ReadRowsOperation": "_ReadRowsOperation_Sync", # "Table": "Table_Sync", # "BigtableDataClient": "BigtableDataClient_Sync", @@ -430,9 +430,9 @@ def transform_from_config(config_dict: dict): { "path": "google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync", "autogen_sync_name": "_ReadRowsOperation_SyncGen", - "pass_methods": [], - "drop_methods": [], - "error_methods": [], + "pass_methods": [], # useful if you want to keep docstring + "drop_methods": [], # useful when the function can be completely removed + "error_methods": [], # useful when the implementation and docstring depend heavily on asyncio }, { "path": "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync", @@ -441,7 +441,8 @@ def transform_from_config(config_dict: dict): { "path": "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync", "autogen_sync_name": "MutationsBatcher_SyncGen", - "pass_methods": ["_start_flush_timer", "close", "_create_bg_task", "_wait_for_batch_results"] + "pass_methods": ["close", "_wait_for_batch_results"], + "error_methods": ["_create_bg_task", "_start_flush_timer"] }, { "path": "google.cloud.bigtable.data._async.mutations_batcher._FlowControlAsync", @@ -450,12 +451,13 @@ def transform_from_config(config_dict: dict): { "path": "google.cloud.bigtable.data._async.client.BigtableDataClientAsync", "autogen_sync_name": "BigtableDataClient_SyncGen", - "pass_methods": ["_start_background_channel_refresh", "close", "_ping_and_warm_instances"] + "pass_methods": ["close", "_ping_and_warm_instances"], + "error_methods": ["_start_background_channel_refresh"], }, { "path": "google.cloud.bigtable.data._async.client.TableAsync", "autogen_sync_name": "Table_SyncGen", - "pass_methods": ["__init__", "read_rows_sharded"] + "pass_methods": ["__init__", "read_rows_sharded"], }, ] } From 5c96cd3e999c7e4e6dafc99e4b23d98e6d973dd9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Apr 2024 17:30:26 -0700 Subject: [PATCH 012/360] improved import replacement code --- google/cloud/bigtable/data/_sync/_autogen.py | 9 +++-- sync_surface_generator.py | 
41 ++++++++------------ 2 files changed, 22 insertions(+), 28 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 634bc2664..76460b091 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -83,7 +83,6 @@ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter from google.cloud.bigtable.mutations_batcher import MutationsBatcherAsync from google.cloud.bigtable.mutations_batcher import _FlowControlAsync -from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( @@ -907,7 +906,7 @@ def __init__( - ValueError if pool_size is less than 1 """ transport_str = f"pooled_grpc_asyncio_{pool_size}" - transport = BigtableGrpcTransport.with_fixed_size(pool_size) + transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) BigtableClientMeta._transport_registry[transport_str] = transport client_info = DEFAULT_CLIENT_INFO client_info.client_library_version = self._client_version() @@ -934,7 +933,9 @@ def __init__( client_options=client_options, client_info=client_info, ) - self.transport = cast(BigtableGrpcTransport, self._gapic_client.transport) + self.transport = cast( + PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport + ) self._active_instances: Set[_WarmedInstanceKey] = set() self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() @@ -972,7 +973,7 @@ def close(self, timeout: float = 2.0): """Cancel all background tasks""" def _ping_and_warm_instances( - self, channel: grpc.aio.Channel, instance_key: _WarmedInstanceKey | None = None + self, channel: grpc.Channel, instance_key: _WarmedInstanceKey | None = None ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel diff --git a/sync_surface_generator.py b/sync_surface_generator.py index abd34b3b0..0c7cfffed 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -143,17 +143,23 @@ def visit_Await(self, node): return self.visit(node.value) def visit_Attribute(self, node): - if ( - isinstance(node.value, ast.Name) - and isinstance(node.value.ctx, ast.Load) - and node.value.id == "asyncio" - and f"asyncio.{node.attr}" in self.import_replacements - ): - replacement = self.import_replacements[f"asyncio.{node.attr}"] + parts = [] + attr_node = node + while isinstance(attr_node, ast.Attribute): + parts.append(attr_node.attr) + attr_node = attr_node.value + if isinstance(attr_node, ast.Name): + parts.append(attr_node.id) + full_name = ".".join(parts[::-1]) + + if full_name in self.import_replacements: + # replace from import_replacements + ieplacement = self.import_replacements[full_name] return ast.copy_location(ast.parse(replacement, mode="eval").body, node) - elif isinstance(node, ast.Attribute) and node.attr in self.name_replacements: + elif node.attr in self.name_replacements: + # replace from name_replacements new_node = ast.copy_location( - ast.Attribute(node.value, self.name_replacements[node.attr], node.ctx), node + ast.Attribute(self.visit(node.value), self.name_replacements[node.attr], node.ctx), node ) return new_node return node @@ -399,9 +405,8 @@ def transform_from_config(config_dict: dict): 
"typing.AsyncIterator": "typing.Iterator", "typing.AsyncGenerator": "typing.Generator", "grpc.aio.Channel": "grpc.Channel", - # "google.cloud.bigtable_v2.services.bigtable.async_client.BigtableAsyncClient": "google.cloud.bigtable_v2.services.bigtable.client.BigtableClient", - # "google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio.PooledBigtableGrpcAsyncIOTransport": "google.cloud.bigtable_v2.services.bigtable.transports.BigtableGrpcTransport", - }, # replace imports with corresponding sync version. Does not touch the code, only import lines + "google.cloud.bigtable_v2.services.bigtable.async_client.BigtableAsyncClient": "google.cloud.bigtable_v2.services.bigtable.client.BigtableClient", + }, # replace imports with corresponding sync version. "name_replacements": { "__anext__": "__next__", "__aiter__": "__iter__", @@ -412,19 +417,7 @@ def transform_from_config(config_dict: dict): "AsyncIterator": "Iterator", "AsyncGenerator": "Generator", "StopAsyncIteration": "StopIteration", - "AsyncRetry": "Retry", "Awaitable": None, - "pytest_asyncio": "pytest", - "AsyncMock": "mock.Mock", - "PooledBigtableGrpcAsyncIOTransport": "BigtableGrpcTransport", - # "BigtableAsyncClient": "BigtableClient", - # "_ReadRowsOperation": "_ReadRowsOperation_Sync", - # "Table": "Table_Sync", - # "BigtableDataClient": "BigtableDataClient_Sync", - # "ReadRowsIterator": "ReadRowsIterator_Sync", - # "_MutateRowsOperation": "_MutateRowsOperation_Sync", - # "MutationsBatcher": "MutationsBatcher_Sync", - # "_FlowControl": "_FlowControl_Sync", }, # performs find/replace for these terms in docstrings and generated code "classes": [ { From 86b2e251f861928d181edae083d8c2c75a565324 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 9 Apr 2024 17:32:14 -0700 Subject: [PATCH 013/360] renamed the two types of replacements --- google/cloud/bigtable/data/_sync/_autogen.py | 3 ++ sync_surface_generator.py | 57 ++++++++++---------- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 76460b091..b0437bc2c 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -85,6 +85,9 @@ from google.cloud.bigtable.mutations_batcher import _FlowControlAsync from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, +) from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( PooledChannel, ) diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 0c7cfffed..a163a0961 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -39,31 +39,30 @@ class AsyncToSyncTransformer(ast.NodeTransformer): outside of this autogeneration system are always applied """ - def __init__(self, *, name=None, import_replacements=None, name_replacements=None, drop_methods=None, pass_methods=None, error_methods=None): + def __init__(self, *, name=None, module_replacements=None, text_replacements=None, drop_methods=None, pass_methods=None, error_methods=None): """ Args: - name: the name of the class being processed. 
Just used in exceptions - - import_replacements: dict of (module, name) to (module, name) replacement import statements - For example, {("foo", "bar"): ("baz", "qux")} will replace "from foo import bar" with "from baz import qux" - - name_replacements: dict of names to replace directly in the source code and docstrings + - module_replacements: modules to replace + - text_replacements: dict of text to replace directly in the source code and docstrings - drop_methods: list of method names to drop from the class - pass_methods: list of method names to replace with "pass" in the class - error_methods: list of method names to replace with "raise NotImplementedError" in the class """ self.name = name - self.import_replacements = import_replacements or {} - self.name_replacements = name_replacements or {} + self.module_replacements = module_replacements or {} + self.text_replacements = text_replacements or {} self.drop_methods = drop_methods or [] self.pass_methods = pass_methods or [] self.error_methods = error_methods or [] def update_docstring(self, docstring): """ - Update docstring to replace any key words in the name_replacements dict + Update docstring to replace any key words in the text_replacements dict """ if not docstring: return docstring - for key_word, replacement in self.name_replacements.items(): + for key_word, replacement in self.text_replacements.items(): docstring = docstring.replace(f" {key_word} ", f" {replacement} ") if "\n" in docstring: # if multiline docstring, add linebreaks to put the """ on a separate line @@ -102,7 +101,7 @@ def visit_AsyncFunctionDef(self, node): and isinstance(n.func, ast.Attribute) \ and isinstance(n.func.value, ast.Name) \ and n.func.value.id == "asyncio" \ - and f"asyncio.{n.func.attr}" not in self.import_replacements: + and f"asyncio.{n.func.attr}" not in self.module_replacements: path_str = f"{self.name}.{node.name}" if self.name else node.name raise RuntimeError( f"{path_str} contains unhandled asyncio calls: {n.func.attr}. Add method to drop_methods, pass_methods, or error_methods to handle safely." 
@@ -115,7 +114,7 @@ def visit_AsyncFunctionDef(self, node): ] return ast.copy_location( ast.FunctionDef( - self.name_replacements.get(node.name, node.name), + self.text_replacements.get(node.name, node.name), self.visit(node.args), [self.visit(stmt) for stmt in node.body], [self.visit(stmt) for stmt in node.decorator_list], @@ -129,7 +128,7 @@ def visit_Call(self, node): if isinstance(node.func, ast.Attribute) and isinstance( node.func.value, ast.Name ): - node.func.value.id = self.name_replacements.get(node.func.value.id, node.func.value.id) + node.func.value.id = self.text_replacements.get(node.func.value.id, node.func.value.id) return ast.copy_location( ast.Call( self.visit(node.func), @@ -152,20 +151,20 @@ def visit_Attribute(self, node): parts.append(attr_node.id) full_name = ".".join(parts[::-1]) - if full_name in self.import_replacements: - # replace from import_replacements - ieplacement = self.import_replacements[full_name] + if full_name in self.module_replacements: + # replace from module_replacements + replacement = self.module_replacements[full_name] return ast.copy_location(ast.parse(replacement, mode="eval").body, node) - elif node.attr in self.name_replacements: - # replace from name_replacements + elif node.attr in self.text_replacements: + # replace from text_replacements new_node = ast.copy_location( - ast.Attribute(self.visit(node.value), self.name_replacements[node.attr], node.ctx), node + ast.Attribute(self.visit(node.value), self.text_replacements[node.attr], node.ctx), node ) return new_node return node def visit_Name(self, node): - node.id = self.name_replacements.get(node.id, node.id) + node.id = self.text_replacements.get(node.id, node.id) return node def visit_AsyncFor(self, node): @@ -220,7 +219,7 @@ def visit_Subscript(self, node): hasattr(node, "value") and isinstance(node.value, ast.Name) and node.value.id == "AsyncGenerator" - and self.name_replacements.get(node.value.id, "") == "Generator" + and self.text_replacements.get(node.value.id, "") == "Generator" ): # Generator has different argument signature than AsyncGenerator return ast.copy_location( @@ -241,7 +240,7 @@ def visit_Subscript(self, node): elif ( hasattr(node, "value") and isinstance(node.value, ast.Name) - and self.name_replacements.get(node.value.id, False) is None + and self.text_replacements.get(node.value.id, False) is None ): # needed for Awaitable return self.visit(node.slice) @@ -268,7 +267,7 @@ def _create_error_node(node, error_msg): def get_imports(self, filename): """ - Get the imports from a file, and do a find-and-replace against import_replacements + Get the imports from a file, and do a find-and-replace against module_replacements """ imports = set() with open(filename, "r") as f: @@ -278,14 +277,14 @@ def get_imports(self, filename): for alias in node.names: if isinstance(node, ast.Import): # import statments - new_import = self.import_replacements.get(alias.name, alias.name) + new_import = self.module_replacements.get(alias.name, alias.name) imports.add(ast.parse(f"import {new_import}").body[0]) else: # import from statements # break into individual components full_path = f"{node.module}.{alias.name}" - if full_path in self.import_replacements: - full_path = self.import_replacements[full_path] + if full_path in self.module_replacements: + full_path = self.module_replacements[full_path] module, name = full_path.rsplit(".", 1) # don't import from same file if module == ".": @@ -349,8 +348,8 @@ def transform_from_config(config_dict: dict): module_path, class_name = 
class_dict.pop("path").rsplit(".", 1) class_object = getattr(importlib.import_module(module_path), class_name) # add globals to class_dict - class_dict["import_replacements"] = {**config_dict.get("import_replacements", {}), **class_dict.get("import_replacements", {})} - class_dict["name_replacements"] = {**config_dict.get("name_replacements", {}), **class_dict.get("name_replacements", {})} + class_dict["module_replacements"] = {**config_dict.get("module_replacements", {}), **class_dict.get("module_replacements", {})} + class_dict["text_replacements"] = {**config_dict.get("text_replacements", {}), **class_dict.get("text_replacements", {})} # transform class tree_body, imports = transform_class(class_object, **class_dict) # update combined data @@ -395,7 +394,7 @@ def transform_from_config(config_dict: dict): if __name__ == "__main__": config = { - "import_replacements": { + "module_replacements": { "asyncio.sleep": "time.sleep", "asyncio.Queue": "queue.Queue", "asyncio.Condition": "threading.Condition", @@ -406,8 +405,8 @@ def transform_from_config(config_dict: dict): "typing.AsyncGenerator": "typing.Generator", "grpc.aio.Channel": "grpc.Channel", "google.cloud.bigtable_v2.services.bigtable.async_client.BigtableAsyncClient": "google.cloud.bigtable_v2.services.bigtable.client.BigtableClient", - }, # replace imports with corresponding sync version. - "name_replacements": { + }, # replace modules with corresponding sync version. + "text_replacements": { "__anext__": "__next__", "__aiter__": "__iter__", "__aenter__": "__enter__", From db16d0fbadb0270df6dab2e78f5180d5c706f3ca Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 10 Apr 2024 14:46:36 -0700 Subject: [PATCH 014/360] use concrete names in generator --- google/cloud/bigtable/data/_sync/_autogen.py | 66 +++++++++----------- sync_surface_generator.py | 18 +++++- 2 files changed, 47 insertions(+), 37 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index b0437bc2c..2b1bd19e9 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -43,19 +43,13 @@ from google.api_core.exceptions import ServiceUnavailable from google.api_core.retry import exponential_sleep_generator from google.cloud.bigtable._mutate_rows import _EntryWithProto -from google.cloud.bigtable._mutate_rows import _MutateRowsOperationAsync -from google.cloud.bigtable._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable._read_rows import _ResetRow -from google.cloud.bigtable.client import BigtableDataClientAsync from google.cloud.bigtable.client import TableAsync from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT from google.cloud.bigtable.data._async._mutate_rows import ( _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, ) -from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync -from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data._async.client import TableAsync -from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery @@ -66,6 +60,12 @@ from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import 
_retry_exception_factory +from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation +from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation +from google.cloud.bigtable.data._sync.client import BigtableDataClient +from google.cloud.bigtable.data._sync.client import Table +from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher +from google.cloud.bigtable.data._sync.mutations_batcher import _FlowControl from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup @@ -81,8 +81,6 @@ from google.cloud.bigtable.data.row_filters import RowFilter from google.cloud.bigtable.data.row_filters import RowFilterChain from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter -from google.cloud.bigtable.mutations_batcher import MutationsBatcherAsync -from google.cloud.bigtable.mutations_batcher import _FlowControlAsync from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( @@ -565,7 +563,7 @@ def __init__( self._table = table self._staged_entries: list[RowMutationEntry] = [] (self._staged_count, self._staged_bytes) = (0, 0) - self._flow_control = _FlowControlAsync( + self._flow_control = _FlowControl( flow_control_max_mutation_count, flow_control_max_bytes ) self._flush_limit_bytes = flush_limit_bytes @@ -661,7 +659,7 @@ def _execute_mutate_rows( FailedMutationEntryError objects will not contain index information """ try: - operation = _MutateRowsOperationAsync( + operation = _MutateRowsOperation( self._table.client._gapic_client, self._table, batch, @@ -887,7 +885,7 @@ def __init__( Client should be created within an async context (running event loop) - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1034,7 +1032,7 @@ def _manage_channel( next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = next_refresh - (time.time() - start_timestamp) - def _register_instance(self, instance_id: str, owner: TableAsync) -> None: + def _register_instance(self, instance_id: str, owner: Table) -> None: """ Registers an instance with the client, and warms the channel pool for the instance @@ -1061,9 +1059,7 @@ def _register_instance(self, instance_id: str, owner: TableAsync) -> None: else: self._start_background_channel_refresh() - def _remove_instance_registration( - self, instance_id: str, owner: TableAsync - ) -> bool: + def _remove_instance_registration(self, instance_id: str, owner: Table) -> bool: """ Removes an instance from the client's registered instances, to prevent warming new channels for the instance @@ -1091,10 +1087,10 @@ def _remove_instance_registration( except KeyError: return False - def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: + def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> Table: """ Returns a table instance for making data API requests. All arguments are passed - directly to the TableAsync constructor. + directly to the Table constructor. Args: instance_id: The Bigtable instance ID to associate with this client. 
@@ -1126,7 +1122,7 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs encountered during all other operations. Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) """ - return TableAsync(self, instance_id, table_id, *args, **kwargs) + return Table(self, instance_id, table_id, *args, **kwargs) def __enter__(self): self._start_background_channel_refresh() @@ -1147,7 +1143,7 @@ class Table_SyncGen(ABC): def __init__( self, - client: BigtableDataClientAsync, + client: BigtableDataClient, instance_id: str, table_id: str, app_profile_id: str | None = None, @@ -1226,7 +1222,7 @@ def read_rows_stream( Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1253,7 +1249,7 @@ def read_rows_stream( operation_timeout, attempt_timeout, self ) retryable_excs = _get_retryable_errors(retryable_errors, self) - row_merger = _ReadRowsOperationAsync( + row_merger = _ReadRowsOperation( query, self, operation_timeout=operation_timeout, @@ -1279,7 +1275,7 @@ def read_rows( Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1328,7 +1324,7 @@ def read_row( Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1386,7 +1382,7 @@ def read_rows_sharded( results = await table.read_rows_sharded(shard_queries) ``` - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1419,7 +1415,7 @@ def row_exists( Return a boolean indicating whether the specified row exists in the table. uses the filters: chain(limit cells per row = 1, strip value) - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1475,7 +1471,7 @@ def sample_row_keys( RowKeySamples is simply a type alias for list[tuple[bytes, int]]; a list of row_keys, along with offset positions in the table - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1538,14 +1534,14 @@ def mutations_batcher( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - ) -> MutationsBatcherAsync: + ) -> MutationsBatcher: """ Returns a new mutations batcher instance. Can be used to iteratively add mutations that are flushed as a group, to avoid excess network calls - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. 
Args: @@ -1564,9 +1560,9 @@ def mutations_batcher( - batch_retryable_errors: a list of errors that will be retried if encountered. Defaults to the Table's default_mutate_rows_retryable_errors. Returns: - - a MutationsBatcherAsync context manager that can batch requests + - a MutationsBatcher context manager that can batch requests """ - return MutationsBatcherAsync( + return MutationsBatcher( self, flush_interval=flush_interval, flush_limit_mutation_count=flush_limit_mutation_count, @@ -1597,7 +1593,7 @@ def mutate_row( Idempotent operations (i.e, all mutations have an explicit timestamp) will be retried on server failure. Non-idempotent operations will not. - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1674,7 +1670,7 @@ def bulk_mutate_rows( will be retried on failure. Non-idempotent will not, and will reported in a raised exception group - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1700,7 +1696,7 @@ def bulk_mutate_rows( operation_timeout, attempt_timeout, self ) retryable_excs = _get_retryable_errors(retryable_errors, self) - operation = _MutateRowsOperationAsync( + operation = _MutateRowsOperation( self.client._gapic_client, self, mutation_entries, @@ -1724,7 +1720,7 @@ def check_and_mutate_row( Non-idempotent operation: will not be retried - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1795,7 +1791,7 @@ def read_modify_write_row( Non-idempotent operation: will not be retried - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. 
Args: diff --git a/sync_surface_generator.py b/sync_surface_generator.py index a163a0961..ac8d4f785 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -342,6 +342,14 @@ def transform_from_config(config_dict: dict): # initialize new tree and import list combined_tree = ast.parse("") combined_imports = set() + # add new concrete classes to text_replacements + global_text_replacements = config_dict.get("text_replacements", {}) + for class_dict in config_dict["classes"]: + if "concrete_path" in class_dict: + class_name = class_dict["path"].rsplit(".", 1)[1] + new_module, new_class_name = class_dict.pop("concrete_path").rsplit(".", 1) + global_text_replacements[class_name] = new_class_name + combined_imports.add(ast.parse(f"from {new_module} import {new_class_name}").body[0]) # process each class for class_dict in config_dict["classes"]: # convert string class path into class object @@ -349,7 +357,7 @@ def transform_from_config(config_dict: dict): class_object = getattr(importlib.import_module(module_path), class_name) # add globals to class_dict class_dict["module_replacements"] = {**config_dict.get("module_replacements", {}), **class_dict.get("module_replacements", {})} - class_dict["text_replacements"] = {**config_dict.get("text_replacements", {}), **class_dict.get("text_replacements", {})} + class_dict["text_replacements"] = {**global_text_replacements, **class_dict.get("text_replacements", {})} # transform class tree_body, imports = transform_class(class_object, **class_dict) # update combined data @@ -422,33 +430,39 @@ def transform_from_config(config_dict: dict): { "path": "google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync", "autogen_sync_name": "_ReadRowsOperation_SyncGen", + "concrete_path": "google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", "pass_methods": [], # useful if you want to keep docstring "drop_methods": [], # useful when the function can be completely removed "error_methods": [], # useful when the implementation and docstring depend heavily on asyncio }, { - "path": "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync", + "path": "google.cloud.bigtable.data._async._mutate_rows._MutateRowsOperationAsync", "autogen_sync_name": "_MutateRowsOperation_SyncGen", + "concrete_path": "google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", }, { "path": "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync", "autogen_sync_name": "MutationsBatcher_SyncGen", + "concrete_path": "google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", "pass_methods": ["close", "_wait_for_batch_results"], "error_methods": ["_create_bg_task", "_start_flush_timer"] }, { "path": "google.cloud.bigtable.data._async.mutations_batcher._FlowControlAsync", "autogen_sync_name": "_FlowControl_SyncGen", + "concrete_path": "google.cloud.bigtable.data._sync.mutations_batcher._FlowControl", }, { "path": "google.cloud.bigtable.data._async.client.BigtableDataClientAsync", "autogen_sync_name": "BigtableDataClient_SyncGen", + "concrete_path": "google.cloud.bigtable.data._sync.client.BigtableDataClient", "pass_methods": ["close", "_ping_and_warm_instances"], "error_methods": ["_start_background_channel_refresh"], }, { "path": "google.cloud.bigtable.data._async.client.TableAsync", "autogen_sync_name": "Table_SyncGen", + "concrete_path": "google.cloud.bigtable.data._sync.client.Table", "pass_methods": ["__init__", "read_rows_sharded"], }, ] From 8c9676d96320598927d207aed30e09b1d452ae9d Mon Sep 17 
00:00:00 2001 From: Daniel Sanche Date: Wed, 10 Apr 2024 14:52:56 -0700 Subject: [PATCH 015/360] fixed import paths --- google/cloud/bigtable/data/_sync/_autogen.py | 5 ++--- sync_surface_generator.py | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 2b1bd19e9..2ddc65fe5 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -42,13 +42,12 @@ from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import ServiceUnavailable from google.api_core.retry import exponential_sleep_generator -from google.cloud.bigtable._mutate_rows import _EntryWithProto -from google.cloud.bigtable._read_rows import _ResetRow -from google.cloud.bigtable.client import TableAsync from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT +from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto from google.cloud.bigtable.data._async._mutate_rows import ( _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, ) +from google.cloud.bigtable.data._async._read_rows import _ResetRow from google.cloud.bigtable.data._async.client import TableAsync from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data._helpers import RowKeySamples diff --git a/sync_surface_generator.py b/sync_surface_generator.py index ac8d4f785..1a09077b1 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -326,13 +326,12 @@ def transform_class(in_obj: Type, **kwargs): # imports.add(ast.parse(f"import {g}").body[0]) # add locals from file, in case they are needed if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): - file_basename = os.path.splitext(os.path.basename(filename))[0] with open(filename, "r") as f: for node in ast.walk(ast.parse(f.read(), filename)): if isinstance(node, ast.ClassDef): imports.add( ast.parse( - f"from google.cloud.bigtable.{file_basename} import {node.name}" + f"from {in_obj.__module__} import {node.name}" ).body[0] ) return ast_tree.body, imports From d4ca86442f0ad9d229551db3f5b2912824a5e8de Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 10 Apr 2024 14:58:27 -0700 Subject: [PATCH 016/360] avoid circular import --- google/cloud/bigtable/data/_sync/_autogen.py | 82 +++++++++++--------- sync_surface_generator.py | 4 +- 2 files changed, 45 insertions(+), 41 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 2ddc65fe5..ef048b2b0 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -59,12 +59,6 @@ from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory -from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation -from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation -from google.cloud.bigtable.data._sync.client import BigtableDataClient -from google.cloud.bigtable.data._sync.client import Table -from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher -from google.cloud.bigtable.data._sync.mutations_batcher import _FlowControl from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data.exceptions import InvalidChunk from 
google.cloud.bigtable.data.exceptions import MutationsExceptionGroup @@ -562,8 +556,10 @@ def __init__( self._table = table self._staged_entries: list[RowMutationEntry] = [] (self._staged_count, self._staged_bytes) = (0, 0) - self._flow_control = _FlowControl( - flow_control_max_mutation_count, flow_control_max_bytes + self._flow_control = ( + google.cloud.bigtable.data._sync.mutations_batcher._FlowControl( + flow_control_max_mutation_count, flow_control_max_bytes + ) ) self._flush_limit_bytes = flush_limit_bytes self._flush_limit_count = ( @@ -658,13 +654,15 @@ def _execute_mutate_rows( FailedMutationEntryError objects will not contain index information """ try: - operation = _MutateRowsOperation( - self._table.client._gapic_client, - self._table, - batch, - operation_timeout=self._operation_timeout, - attempt_timeout=self._attempt_timeout, - retryable_exceptions=self._retryable_errors, + operation = ( + google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation( + self._table.client._gapic_client, + self._table, + batch, + operation_timeout=self._operation_timeout, + attempt_timeout=self._attempt_timeout, + retryable_exceptions=self._retryable_errors, + ) ) operation.start() except MutationsExceptionGroup as e: @@ -884,7 +882,7 @@ def __init__( Client should be created within an async context (running event loop) - Warning: BigtableDataClient is currently in preview, and is not + Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1031,7 +1029,9 @@ def _manage_channel( next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = next_refresh - (time.time() - start_timestamp) - def _register_instance(self, instance_id: str, owner: Table) -> None: + def _register_instance( + self, instance_id: str, owner: google.cloud.bigtable.data._sync.client.Table + ) -> None: """ Registers an instance with the client, and warms the channel pool for the instance @@ -1058,7 +1058,9 @@ def _register_instance(self, instance_id: str, owner: Table) -> None: else: self._start_background_channel_refresh() - def _remove_instance_registration(self, instance_id: str, owner: Table) -> bool: + def _remove_instance_registration( + self, instance_id: str, owner: google.cloud.bigtable.data._sync.client.Table + ) -> bool: """ Removes an instance from the client's registered instances, to prevent warming new channels for the instance @@ -1086,10 +1088,12 @@ def _remove_instance_registration(self, instance_id: str, owner: Table) -> bool: except KeyError: return False - def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> Table: + def get_table( + self, instance_id: str, table_id: str, *args, **kwargs + ) -> google.cloud.bigtable.data._sync.client.Table: """ Returns a table instance for making data API requests. All arguments are passed - directly to the Table constructor. + directly to the google.cloud.bigtable.data._sync.client.Table constructor. Args: instance_id: The Bigtable instance ID to associate with this client. @@ -1121,7 +1125,9 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> Table: encountered during all other operations. 
Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) """ - return Table(self, instance_id, table_id, *args, **kwargs) + return google.cloud.bigtable.data._sync.client.Table( + self, instance_id, table_id, *args, **kwargs + ) def __enter__(self): self._start_background_channel_refresh() @@ -1142,7 +1148,7 @@ class Table_SyncGen(ABC): def __init__( self, - client: BigtableDataClient, + client: google.cloud.bigtable.data._sync.client.BigtableDataClient, instance_id: str, table_id: str, app_profile_id: str | None = None, @@ -1221,7 +1227,7 @@ def read_rows_stream( Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. - Warning: BigtableDataClient is currently in preview, and is not + Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1248,7 +1254,7 @@ def read_rows_stream( operation_timeout, attempt_timeout, self ) retryable_excs = _get_retryable_errors(retryable_errors, self) - row_merger = _ReadRowsOperation( + row_merger = google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation( query, self, operation_timeout=operation_timeout, @@ -1274,7 +1280,7 @@ def read_rows( Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. - Warning: BigtableDataClient is currently in preview, and is not + Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1323,7 +1329,7 @@ def read_row( Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. - Warning: BigtableDataClient is currently in preview, and is not + Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1381,7 +1387,7 @@ def read_rows_sharded( results = await table.read_rows_sharded(shard_queries) ``` - Warning: BigtableDataClient is currently in preview, and is not + Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1414,7 +1420,7 @@ def row_exists( Return a boolean indicating whether the specified row exists in the table. uses the filters: chain(limit cells per row = 1, strip value) - Warning: BigtableDataClient is currently in preview, and is not + Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1470,7 +1476,7 @@ def sample_row_keys( RowKeySamples is simply a type alias for list[tuple[bytes, int]]; a list of row_keys, along with offset positions in the table - Warning: BigtableDataClient is currently in preview, and is not + Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1533,14 +1539,14 @@ def mutations_batcher( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - ) -> MutationsBatcher: + ) -> google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher: """ Returns a new mutations batcher instance. 
Can be used to iteratively add mutations that are flushed as a group, to avoid excess network calls - Warning: BigtableDataClient is currently in preview, and is not + Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1559,9 +1565,9 @@ def mutations_batcher( - batch_retryable_errors: a list of errors that will be retried if encountered. Defaults to the Table's default_mutate_rows_retryable_errors. Returns: - - a MutationsBatcher context manager that can batch requests + - a google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher context manager that can batch requests """ - return MutationsBatcher( + return google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher( self, flush_interval=flush_interval, flush_limit_mutation_count=flush_limit_mutation_count, @@ -1592,7 +1598,7 @@ def mutate_row( Idempotent operations (i.e, all mutations have an explicit timestamp) will be retried on server failure. Non-idempotent operations will not. - Warning: BigtableDataClient is currently in preview, and is not + Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1669,7 +1675,7 @@ def bulk_mutate_rows( will be retried on failure. Non-idempotent will not, and will reported in a raised exception group - Warning: BigtableDataClient is currently in preview, and is not + Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1695,7 +1701,7 @@ def bulk_mutate_rows( operation_timeout, attempt_timeout, self ) retryable_excs = _get_retryable_errors(retryable_errors, self) - operation = _MutateRowsOperation( + operation = google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation( self.client._gapic_client, self, mutation_entries, @@ -1719,7 +1725,7 @@ def check_and_mutate_row( Non-idempotent operation: will not be retried - Warning: BigtableDataClient is currently in preview, and is not + Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1790,7 +1796,7 @@ def read_modify_write_row( Non-idempotent operation: will not be retried - Warning: BigtableDataClient is currently in preview, and is not + Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not yet recommended for production use. 
Args: diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 1a09077b1..aaefd9758 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -346,9 +346,7 @@ def transform_from_config(config_dict: dict): for class_dict in config_dict["classes"]: if "concrete_path" in class_dict: class_name = class_dict["path"].rsplit(".", 1)[1] - new_module, new_class_name = class_dict.pop("concrete_path").rsplit(".", 1) - global_text_replacements[class_name] = new_class_name - combined_imports.add(ast.parse(f"from {new_module} import {new_class_name}").body[0]) + global_text_replacements[class_name] = class_dict.pop("concrete_path") # process each class for class_dict in config_dict["classes"]: # convert string class path into class object From 2d7b88f11101c65d140813c185f8edc8cb6800a6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 10 Apr 2024 16:10:16 -0700 Subject: [PATCH 017/360] improved string typing --- google/cloud/bigtable/data/_sync/_autogen.py | 12 ++++++------ sync_surface_generator.py | 13 ++++++++----- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index ef048b2b0..c223e313c 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -48,7 +48,6 @@ _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, ) from google.cloud.bigtable.data._async._read_rows import _ResetRow -from google.cloud.bigtable.data._async.client import TableAsync from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery @@ -75,6 +74,7 @@ from google.cloud.bigtable.data.row_filters import RowFilterChain from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO +from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( PooledBigtableGrpcAsyncIOTransport, @@ -121,7 +121,7 @@ class _ReadRowsOperation_SyncGen(ABC): def __init__( self, query: ReadRowsQuery, - table: "TableAsync", + table: "google.cloud.bigtable.data._sync.client.Table", operation_timeout: float, attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), @@ -353,8 +353,8 @@ class _MutateRowsOperation_SyncGen(ABC): def __init__( self, - gapic_client: "BigtableAsyncClient", - table: "TableAsync", + gapic_client: "BigtableClient", + table: "google.cloud.bigtable.data._sync.client.Table", mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, @@ -516,7 +516,7 @@ class MutationsBatcher_SyncGen(ABC): def __init__( self, - table: "TableAsync", + table: "google.cloud.bigtable.data._sync.client.Table", *, flush_interval: float | None = 5, flush_limit_mutation_count: int | None = 1000, @@ -925,7 +925,7 @@ def __init__( project=project, client_options=client_options, ) - self._gapic_client = BigtableAsyncClient( + self._gapic_client = BigtableClient( transport=transport_str, credentials=credentials, client_options=client_options, diff --git a/sync_surface_generator.py b/sync_surface_generator.py index aaefd9758..08f3ce7ef 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -108,10 
+108,16 @@ def visit_AsyncFunctionDef(self, node): ) # remove pytest.mark.asyncio decorator if hasattr(node, "decorator_list"): + # TODO: make generic is_asyncio_decorator = lambda d: all(x in ast.dump(d) for x in ["pytest", "mark", "asyncio"]) node.decorator_list = [ d for d in node.decorator_list if not is_asyncio_decorator(d) ] + # visit string type annotations + for arg in node.args.args: + if arg.annotation: + if isinstance(arg.annotation, ast.Constant): + arg.annotation.value = self.text_replacements.get(arg.annotation.value, arg.annotation.value) return ast.copy_location( ast.FunctionDef( self.text_replacements.get(node.name, node.name), @@ -124,11 +130,6 @@ def visit_AsyncFunctionDef(self, node): ) def visit_Call(self, node): - # name replacement for class method calls - if isinstance(node.func, ast.Attribute) and isinstance( - node.func.value, ast.Name - ): - node.func.value.id = self.text_replacements.get(node.func.value.id, node.func.value.id) return ast.copy_location( ast.Call( self.visit(node.func), @@ -215,6 +216,7 @@ def visit_ListComp(self, node): ) def visit_Subscript(self, node): + # TODO: generalize? if ( hasattr(node, "value") and isinstance(node.value, ast.Name) @@ -422,6 +424,7 @@ def transform_from_config(config_dict: dict): "AsyncGenerator": "Generator", "StopAsyncIteration": "StopIteration", "Awaitable": None, + "BigtableAsyncClient": "BigtableClient", }, # performs find/replace for these terms in docstrings and generated code "classes": [ { From 2c950f21e724b531ab7e5581422b2b8d5fbc5e39 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 10 Apr 2024 16:52:12 -0700 Subject: [PATCH 018/360] got client and table working --- google/cloud/bigtable/data/_async/client.py | 50 ++++++-- google/cloud/bigtable/data/_sync/_autogen.py | 123 +++++++++---------- google/cloud/bigtable/data/_sync/client.py | 24 +++- sync_surface_generator.py | 8 +- 4 files changed, 124 insertions(+), 81 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index ed14c618d..bfd9d4e84 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -123,9 +123,7 @@ def __init__( - ValueError if pool_size is less than 1 """ # set up transport in registry - transport_str = f"pooled_grpc_asyncio_{pool_size}" - transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) - BigtableClientMeta._transport_registry[transport_str] = transport + transport_str = self._transport_init(pool_size) # set up client info headers for veneer library client_info = DEFAULT_CLIENT_INFO client_info.client_library_version = self._client_version() @@ -172,11 +170,7 @@ def __init__( RuntimeWarning, stacklevel=2, ) - self.transport._grpc_channel = PooledChannel( - pool_size=pool_size, - host=self._emulator_host, - insecure=True, - ) + self._prep_emulator_channel(pool_size) # refresh cached stubs to use emulator pool self.transport._stubs = {} self.transport._prep_wrapped_messages(client_info) @@ -192,6 +186,29 @@ def __init__( stacklevel=2, ) + def _transport_init(self, pool_size: int) -> str: + """ + Helper function for intiializing the transport object + + Different implementations for sync vs async client + """ + transport_str = f"pooled_grpc_asyncio_{pool_size}" + transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) + BigtableClientMeta._transport_registry[transport_str] = transport + return transport_str + + def _prep_emulator_channel(self, pool_size:int): + """ + Helper function 
for initializing emulator's insecure grpc channel + + Different implementations for sync vs async client + """ + self.transport._grpc_channel = PooledChannel( + pool_size=pool_size, + host=self._emulator_host, + insecure=True, + ) + @staticmethod def _client_version() -> str: """ @@ -539,11 +556,20 @@ def __init__( default_mutate_rows_retryable_errors or () ) self.default_retryable_errors = default_retryable_errors or () + self._register_with_client() + + + def _register_with_client(self): + """ + Calls the client's _register_instance method to warm the grpc channels for this instance + + Different implementations for sync vs async client + """ # raises RuntimeError if called outside of an async context (no running event loop) try: self._register_instance_task = asyncio.create_task( - self.client._register_instance(instance_id, self) + self.client._register_instance(self.instance_id, self) ) except RuntimeError as e: raise RuntimeError( @@ -1241,7 +1267,8 @@ async def close(self): """ Called to close the Table instance and release any resources held by it. """ - self._register_instance_task.cancel() + if self._register_instance_task: + self._register_instance_task.cancel() await self.client._remove_instance_registration(self.instance_id, self) async def __aenter__(self): @@ -1251,7 +1278,8 @@ async def __aenter__(self): Ensure registration task has time to run, so that grpc channels will be warmed for the specified instance """ - await self._register_instance_task + if self._register_instance_task: + await self._register_instance_task return self async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index c223e313c..c87d1ebde 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -31,7 +31,6 @@ import functools import grpc import os -import random import time import warnings @@ -58,6 +57,7 @@ from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data._helpers import _validate_timeouts from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup @@ -75,12 +75,8 @@ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient -from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, -) -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel, +from google.cloud.bigtable_v2.services.bigtable.transports.grpc import ( + BigtableGrpcTransport, ) from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB @@ -903,9 +899,7 @@ def __init__( - RuntimeError if called outside of an async context (no running event loop) - ValueError if pool_size is less than 1 """ - transport_str = f"pooled_grpc_asyncio_{pool_size}" - transport = 
PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) - BigtableClientMeta._transport_registry[transport_str] = transport + transport_str = self._transport_init(pool_size) client_info = DEFAULT_CLIENT_INFO client_info.client_library_version = self._client_version() if type(client_options) is dict: @@ -931,9 +925,7 @@ def __init__( client_options=client_options, client_info=client_info, ) - self.transport = cast( - PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport - ) + self.transport = cast(BigtableGrpcTransport, self._gapic_client.transport) self._active_instances: Set[_WarmedInstanceKey] = set() self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() @@ -944,9 +936,7 @@ def __init__( RuntimeWarning, stacklevel=2, ) - self.transport._grpc_channel = PooledChannel( - pool_size=pool_size, host=self._emulator_host, insecure=True - ) + self._prep_emulator_channel(pool_size) self.transport._stubs = {} self.transport._prep_wrapped_messages(client_info) else: @@ -959,10 +949,15 @@ def __init__( stacklevel=2, ) + def _transport_init(self, pool_size: int) -> str: + raise NotImplementedError("Function not implemented in sync class") + + def _prep_emulator_channel(self, pool_size: int): + raise NotImplementedError("Function not implemented in sync class") + @staticmethod def _client_version() -> str: - """Helper function to return the client version string for this client""" - return f"{google.cloud.bigtable.__version__}-data-async" + raise NotImplementedError("Function not implemented in sync class") def _start_background_channel_refresh(self) -> None: raise NotImplementedError("Function not implemented in sync class") @@ -985,50 +980,6 @@ def _ping_and_warm_instances( - sequence of results or exceptions from the ping requests """ - def _manage_channel( - self, - channel_idx: int, - refresh_interval_min: float = 60 * 35, - refresh_interval_max: float = 60 * 45, - grace_period: float = 60 * 10, - ) -> None: - """ - Background coroutine that periodically refreshes and warms a grpc channel - - The backend will automatically close channels after 60 minutes, so - `refresh_interval` + `grace_period` should be < 60 minutes - - Runs continuously until the client is closed - - Args: - channel_idx: index of the channel in the transport's channel pool - refresh_interval_min: minimum interval before initiating refresh - process in seconds. Actual interval will be a random value - between `refresh_interval_min` and `refresh_interval_max` - refresh_interval_max: maximum interval before initiating refresh - process in seconds. 
Actual interval will be a random value - between `refresh_interval_min` and `refresh_interval_max` - grace_period: time to allow previous channel to serve existing - requests before closing, in seconds - """ - first_refresh = self._channel_init_time + random.uniform( - refresh_interval_min, refresh_interval_max - ) - next_sleep = max(first_refresh - time.monotonic(), 0) - if next_sleep > 0: - channel = self.transport.channels[channel_idx] - self._ping_and_warm_instances(channel) - while True: - time.sleep(next_sleep) - new_channel = self.transport.grpc_channel._create_channel() - self._ping_and_warm_instances(new_channel) - start_timestamp = time.time() - self.transport.replace_channel( - channel_idx, grace=grace_period, swap_sleep=10, new_channel=new_channel - ) - next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) - next_sleep = next_refresh - (time.time() - start_timestamp) - def _register_instance( self, instance_id: str, owner: google.cloud.bigtable.data._sync.client.Table ) -> None: @@ -1210,6 +1161,48 @@ def __init__( Raises: - RuntimeError if called outside of an async context (no running event loop) """ + _validate_timeouts( + default_operation_timeout, default_attempt_timeout, allow_none=True + ) + _validate_timeouts( + default_read_rows_operation_timeout, + default_read_rows_attempt_timeout, + allow_none=True, + ) + _validate_timeouts( + default_mutate_rows_operation_timeout, + default_mutate_rows_attempt_timeout, + allow_none=True, + ) + self.client = client + self.instance_id = instance_id + self.instance_name = self.client._gapic_client.instance_path( + self.client.project, instance_id + ) + self.table_id = table_id + self.table_name = self.client._gapic_client.table_path( + self.client.project, instance_id, table_id + ) + self.app_profile_id = app_profile_id + self.default_operation_timeout = default_operation_timeout + self.default_attempt_timeout = default_attempt_timeout + self.default_read_rows_operation_timeout = default_read_rows_operation_timeout + self.default_read_rows_attempt_timeout = default_read_rows_attempt_timeout + self.default_mutate_rows_operation_timeout = ( + default_mutate_rows_operation_timeout + ) + self.default_mutate_rows_attempt_timeout = default_mutate_rows_attempt_timeout + self.default_read_rows_retryable_errors = ( + default_read_rows_retryable_errors or () + ) + self.default_mutate_rows_retryable_errors = ( + default_mutate_rows_retryable_errors or () + ) + self.default_retryable_errors = default_retryable_errors or () + self._register_with_client() + + def _register_with_client(self): + raise NotImplementedError("Function not implemented in sync class") def read_rows_stream( self, @@ -1835,7 +1828,8 @@ def read_modify_write_row( def close(self): """Called to close the Table instance and release any resources held by it.""" - self._register_instance_task.cancel() + if self._register_instance_task: + self._register_instance_task.cancel() self.client._remove_instance_registration(self.instance_id, self) def __enter__(self): @@ -1845,7 +1839,8 @@ def __enter__(self): Ensure registration task has time to run, so that grpc channels will be warmed for the specified instance """ - self._register_instance_task + if self._register_instance_task: + self._register_instance_task return self def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 2d87f0758..0f94ad7d9 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ 
b/google/cloud/bigtable/data/_sync/client.py @@ -16,6 +16,8 @@ from typing import Any +import grpc + import google.auth.credentials from google.cloud.bigtable.data._sync._autogen import BigtableDataClient_SyncGen @@ -34,12 +36,26 @@ def __init__( ): # remove pool size option in sync client super().__init__( - project=project, credentials=credentials, client_options=client_options + project=project, credentials=credentials, client_options=client_options, pool_size=1 ) + def _transport_init(self, pool_size: int) -> str: + return "grpc" + + def _prep_emulator_channel(self, pool_size: int) -> str: + self.transport._grpc_channel = grpc.insecure_channel(host=self._emulator_host) + + @staticmethod + def _client_version() -> str: + return f"{google.cloud.bigtable.__version__}-data" + + def _start_background_channel_refresh(self) -> None: + # TODO: implement channel refresh + pass + class Table(Table_SyncGen): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # register table with client + + def _register_with_client(self): self.client._register_instance(self.instance_id, self) + self._register_instance_task = None diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 08f3ce7ef..4c9996f78 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -412,6 +412,7 @@ def transform_from_config(config_dict: dict): "typing.AsyncGenerator": "typing.Generator", "grpc.aio.Channel": "grpc.Channel", "google.cloud.bigtable_v2.services.bigtable.async_client.BigtableAsyncClient": "google.cloud.bigtable_v2.services.bigtable.client.BigtableClient", + "google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio.PooledBigtableGrpcAsyncIOTransport": "google.cloud.bigtable_v2.services.bigtable.transports.grpc.BigtableGrpcTransport", }, # replace modules with corresponding sync version. 
"text_replacements": { "__anext__": "__next__", @@ -425,6 +426,7 @@ def transform_from_config(config_dict: dict): "StopAsyncIteration": "StopIteration", "Awaitable": None, "BigtableAsyncClient": "BigtableClient", + "PooledBigtableGrpcAsyncIOTransport": "BigtableGrpcTransport", }, # performs find/replace for these terms in docstrings and generated code "classes": [ { @@ -457,13 +459,15 @@ def transform_from_config(config_dict: dict): "autogen_sync_name": "BigtableDataClient_SyncGen", "concrete_path": "google.cloud.bigtable.data._sync.client.BigtableDataClient", "pass_methods": ["close", "_ping_and_warm_instances"], - "error_methods": ["_start_background_channel_refresh"], + "drop_methods": ["_manage_channel"], + "error_methods": ["_start_background_channel_refresh", "_client_version", "_prep_emulator_channel", "_transport_init"] }, { "path": "google.cloud.bigtable.data._async.client.TableAsync", "autogen_sync_name": "Table_SyncGen", "concrete_path": "google.cloud.bigtable.data._sync.client.Table", - "pass_methods": ["__init__", "read_rows_sharded"], + "pass_methods": ["read_rows_sharded"], + "error_methods": ["_register_with_client"] }, ] } From f73a225f9fbfa70f3281af7b7ae4ddef66edd8d8 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 10 Apr 2024 17:03:04 -0700 Subject: [PATCH 019/360] moved config into yaml file --- google/cloud/bigtable/data/_sync/_autogen.py | 45 +---------- .../cloud/bigtable/data/_sync/sync_gen.yaml | 54 ++++++++++++- sync_surface_generator.py | 80 +------------------ 3 files changed, 57 insertions(+), 122 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index c87d1ebde..fa129fff6 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -49,7 +49,6 @@ from google.cloud.bigtable.data._async._read_rows import _ResetRow from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data._helpers import RowKeySamples -from google.cloud.bigtable.data._helpers import ShardedQuery from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import _WarmedInstanceKey from google.cloud.bigtable.data._helpers import _attempt_timeout_generator @@ -179,7 +178,7 @@ def _read_rows_attempt(self) -> Generator[Row, None, "Any"]: return self.merge_rows(chunked_stream) def chunk_stream( - self, stream: Iterable[ReadRowsResponsePB] + self, stream: None[Iterable[ReadRowsResponsePB]] ) -> Generator[ReadRowsResponsePB.CellChunk, None, "Any"]: """process chunks out of raw read_rows stream""" for resp in stream: @@ -1358,48 +1357,6 @@ def read_row( return None return results[0] - def read_rows_sharded( - self, - sharded_query: ShardedQuery, - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - ) -> list[Row]: - """ - Runs a sharded query in parallel, then return the results in a single list. - Results will be returned in the order of the input queries. - - This function is intended to be run on the results on a query.shard() call: - - ``` - table_shard_keys = await table.sample_row_keys() - query = ReadRowsQuery(...) 
- shard_queries = query.shard(table_shard_keys) - results = await table.read_rows_sharded(shard_queries) - ``` - - Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not - yet recommended for production use. - - Args: - - sharded_query: a sharded query to execute - - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_read_rows_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_read_rows_attempt_timeout. - If None, defaults to operation_timeout. - - retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_read_rows_retryable_errors. - Raises: - - ShardedReadRowsExceptionGroup: if any of the queries failed - - ValueError: if the query_list is empty - """ - def row_exists( self, row_key: str | bytes, diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index 88c3e5d1d..fe7447335 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -1,3 +1,53 @@ -async_classes: - - "google.cloud.bigtable._async._read_rows._ReadRowsOperationAsync" +module_replacements: # Replace entire modules + asyncio.sleep: time.sleep + asyncio.Queue: queue.Queue + asyncio.Condition: threading.Condition + asyncio.Future: concurrent.futures.Future + google.api_core.retry_async: google.api_core.retry + typing.AsyncIterable: typing.Iterable + typing.AsyncIterator: typing.Iterator + typing.AsyncGenerator: typing.Generator + grpc.aio.Channel: grpc.Channel + google.cloud.bigtable_v2.services.bigtable.async_client.BigtableAsyncClient: google.cloud.bigtable_v2.services.bigtable.client.BigtableClient + google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio.PooledBigtableGrpcAsyncIOTransport: google.cloud.bigtable_v2.services.bigtable.transports.grpc.BigtableGrpcTransport +text_replacements: # Find and replace specific text patterns + __anext__: __next__ + __aiter__: __iter__ + __aenter__: __enter__ + __aexit__: __exit__ + aclose: close + AsyncIterable: Iterable + AsyncIterator: Iterator + AsyncGenerator: Generator + StopAsyncIteration: StopIteration + Awaitable: None + BigtableAsyncClient: BigtableClient + PooledBigtableGrpcAsyncIOTransport: BigtableGrpcTransport + +classes: # Specify transformations for individual classes + - path: google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync + autogen_sync_name: _ReadRowsOperation_SyncGen + concrete_path: google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation + - path: google.cloud.bigtable.data._async._mutate_rows._MutateRowsOperationAsync + autogen_sync_name: _MutateRowsOperation_SyncGen + concrete_path: google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation + - path: google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync + autogen_sync_name: MutationsBatcher_SyncGen + concrete_path: google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher + pass_methods: ["close", "_wait_for_batch_results"] + error_methods: ["_create_bg_task", "_start_flush_timer"] + - path: google.cloud.bigtable.data._async.mutations_batcher._FlowControlAsync + autogen_sync_name: 
_FlowControl_SyncGen + concrete_path: google.cloud.bigtable.data._sync.mutations_batcher._FlowControl + - path: google.cloud.bigtable.data._async.client.BigtableDataClientAsync + autogen_sync_name: BigtableDataClient_SyncGen + concrete_path: google.cloud.bigtable.data._sync.client.BigtableDataClient + pass_methods: ["close", "_ping_and_warm_instances"] + drop_methods: ["_manage_channel"] + error_methods: ["_start_background_channel_refresh", "_client_version", "_prep_emulator_channel", "_transport_init"] + - path: google.cloud.bigtable.data._async.client.TableAsync + autogen_sync_name: Table_SyncGen + concrete_path: google.cloud.bigtable.data._sync.client.Table + drop_methods: ["read_rows_sharded"] + error_methods: ["_register_with_client"] diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 4c9996f78..5f6d674f4 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -18,12 +18,9 @@ import inspect import ast import textwrap -import time -import queue -import os -import threading -import concurrent.futures import importlib +import yaml +from pathlib import Path from black import format_str, FileMode import autoflake @@ -400,77 +397,8 @@ def transform_from_config(config_dict: dict): if __name__ == "__main__": - config = { - "module_replacements": { - "asyncio.sleep": "time.sleep", - "asyncio.Queue": "queue.Queue", - "asyncio.Condition": "threading.Condition", - "asyncio.Future": "concurrent.futures.Future", - "google.api_core.retry_async": "google.api_core.retry", - "typing.AsyncIterable": "typing.Iterable", - "typing.AsyncIterator": "typing.Iterator", - "typing.AsyncGenerator": "typing.Generator", - "grpc.aio.Channel": "grpc.Channel", - "google.cloud.bigtable_v2.services.bigtable.async_client.BigtableAsyncClient": "google.cloud.bigtable_v2.services.bigtable.client.BigtableClient", - "google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio.PooledBigtableGrpcAsyncIOTransport": "google.cloud.bigtable_v2.services.bigtable.transports.grpc.BigtableGrpcTransport", - }, # replace modules with corresponding sync version. 
- "text_replacements": { - "__anext__": "__next__", - "__aiter__": "__iter__", - "__aenter__": "__enter__", - "__aexit__": "__exit__", - "aclose": "close", - "AsyncIterable": "Iterable", - "AsyncIterator": "Iterator", - "AsyncGenerator": "Generator", - "StopAsyncIteration": "StopIteration", - "Awaitable": None, - "BigtableAsyncClient": "BigtableClient", - "PooledBigtableGrpcAsyncIOTransport": "BigtableGrpcTransport", - }, # performs find/replace for these terms in docstrings and generated code - "classes": [ - { - "path": "google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync", - "autogen_sync_name": "_ReadRowsOperation_SyncGen", - "concrete_path": "google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", - "pass_methods": [], # useful if you want to keep docstring - "drop_methods": [], # useful when the function can be completely removed - "error_methods": [], # useful when the implementation and docstring depend heavily on asyncio - }, - { - "path": "google.cloud.bigtable.data._async._mutate_rows._MutateRowsOperationAsync", - "autogen_sync_name": "_MutateRowsOperation_SyncGen", - "concrete_path": "google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", - }, - { - "path": "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync", - "autogen_sync_name": "MutationsBatcher_SyncGen", - "concrete_path": "google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", - "pass_methods": ["close", "_wait_for_batch_results"], - "error_methods": ["_create_bg_task", "_start_flush_timer"] - }, - { - "path": "google.cloud.bigtable.data._async.mutations_batcher._FlowControlAsync", - "autogen_sync_name": "_FlowControl_SyncGen", - "concrete_path": "google.cloud.bigtable.data._sync.mutations_batcher._FlowControl", - }, - { - "path": "google.cloud.bigtable.data._async.client.BigtableDataClientAsync", - "autogen_sync_name": "BigtableDataClient_SyncGen", - "concrete_path": "google.cloud.bigtable.data._sync.client.BigtableDataClient", - "pass_methods": ["close", "_ping_and_warm_instances"], - "drop_methods": ["_manage_channel"], - "error_methods": ["_start_background_channel_refresh", "_client_version", "_prep_emulator_channel", "_transport_init"] - }, - { - "path": "google.cloud.bigtable.data._async.client.TableAsync", - "autogen_sync_name": "Table_SyncGen", - "concrete_path": "google.cloud.bigtable.data._sync.client.Table", - "pass_methods": ["read_rows_sharded"], - "error_methods": ["_register_with_client"] - }, - ] - } + load_path = "./google/cloud/bigtable/data/_sync/sync_gen.yaml" + config = yaml.safe_load(Path(load_path).read_text()) save_path = "./google/cloud/bigtable/data/_sync/_autogen.py" From 92d5608016501ed9d66eaa0bc545687e51187e20 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 10 Apr 2024 17:05:06 -0700 Subject: [PATCH 020/360] moved save_path into yaml --- google/cloud/bigtable/data/_sync/sync_gen.yaml | 2 ++ sync_surface_generator.py | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index fe7447335..145c5d45e 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -51,3 +51,5 @@ classes: # Specify transformations for individual classes concrete_path: google.cloud.bigtable.data._sync.client.Table drop_methods: ["read_rows_sharded"] error_methods: ["_register_with_client"] + +save_path: "google/cloud/bigtable/data/_sync/_autogen.py" diff --git 
a/sync_surface_generator.py b/sync_surface_generator.py index 5f6d674f4..50986ac4f 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -400,8 +400,7 @@ def transform_from_config(config_dict: dict): load_path = "./google/cloud/bigtable/data/_sync/sync_gen.yaml" config = yaml.safe_load(Path(load_path).read_text()) - save_path = "./google/cloud/bigtable/data/_sync/_autogen.py" - + save_path = config.get("save_path") code = transform_from_config(config) if save_path is not None: From 7b40f279c47240f97ee9c258205e4582fa788545 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Apr 2024 14:04:03 -0700 Subject: [PATCH 021/360] fixed sync emulator channel --- google/cloud/bigtable/data/_sync/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 0f94ad7d9..d95f82107 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -43,7 +43,7 @@ def _transport_init(self, pool_size: int) -> str: return "grpc" def _prep_emulator_channel(self, pool_size: int) -> str: - self.transport._grpc_channel = grpc.insecure_channel(host=self._emulator_host) + self.transport._grpc_channel = grpc.insecure_channel(target=self._emulator_host) @staticmethod def _client_version() -> str: From b3546d80dc3a653e35c7c254cd70a0233d0d42f0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Apr 2024 14:06:19 -0700 Subject: [PATCH 022/360] moved acceptance tests to async folder --- tests/unit/data/{ => _async}/test_read_rows_acceptance.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/unit/data/{ => _async}/test_read_rows_acceptance.py (100%) diff --git a/tests/unit/data/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py similarity index 100% rename from tests/unit/data/test_read_rows_acceptance.py rename to tests/unit/data/_async/test_read_rows_acceptance.py From cbec8f2e818ebc498f34fa53b64fb6997f0b7e2a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Apr 2024 14:33:40 -0700 Subject: [PATCH 023/360] broke module_replacements into asyncio and imports --- google/cloud/bigtable/data/_async/client.py | 4 +- google/cloud/bigtable/data/_sync/_autogen.py | 7 ++- .../cloud/bigtable/data/_sync/sync_gen.yaml | 24 +++++----- sync_surface_generator.py | 47 ++++++++----------- 4 files changed, 36 insertions(+), 46 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index bfd9d4e84..bd9e8c8c0 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -26,7 +26,6 @@ ) import asyncio -import grpc import time import warnings import sys @@ -34,6 +33,7 @@ import os from functools import partial +from grpc.aio import Channel as AsyncChannel from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient @@ -246,7 +246,7 @@ async def close(self, timeout: float = 2.0): self._channel_refresh_tasks = [] async def _ping_and_warm_instances( - self, channel: grpc.aio.Channel, instance_key: _WarmedInstanceKey | None = None + self, channel: AsyncChannel, instance_key: _WarmedInstanceKey | None = None ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py 
index fa129fff6..6bfa2c164 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -19,9 +19,9 @@ from abc import ABC from collections import deque from functools import partial +from grpc import Channel from typing import Any -from typing import Generator -from typing import Iterable +from typing import Generator, Iterable from typing import Optional from typing import Sequence from typing import Set @@ -29,7 +29,6 @@ import asyncio import atexit import functools -import grpc import os import time import warnings @@ -965,7 +964,7 @@ def close(self, timeout: float = 2.0): """Cancel all background tasks""" def _ping_and_warm_instances( - self, channel: grpc.Channel, instance_key: _WarmedInstanceKey | None = None + self, channel: Channel, instance_key: _WarmedInstanceKey | None = None ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index 145c5d45e..b48030ccf 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -1,15 +1,8 @@ -module_replacements: # Replace entire modules - asyncio.sleep: time.sleep - asyncio.Queue: queue.Queue - asyncio.Condition: threading.Condition - asyncio.Future: concurrent.futures.Future - google.api_core.retry_async: google.api_core.retry - typing.AsyncIterable: typing.Iterable - typing.AsyncIterator: typing.Iterator - typing.AsyncGenerator: typing.Generator - grpc.aio.Channel: grpc.Channel - google.cloud.bigtable_v2.services.bigtable.async_client.BigtableAsyncClient: google.cloud.bigtable_v2.services.bigtable.client.BigtableClient - google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio.PooledBigtableGrpcAsyncIOTransport: google.cloud.bigtable_v2.services.bigtable.transports.grpc.BigtableGrpcTransport +asyncio_replacements: # Replace asyncio functionaility + sleep: time.sleep + Queue: queue.Queue + Condition: threading.Condition + Future: concurrent.futures.Future text_replacements: # Find and replace specific text patterns __anext__: __next__ @@ -24,6 +17,13 @@ text_replacements: # Find and replace specific text patterns Awaitable: None BigtableAsyncClient: BigtableClient PooledBigtableGrpcAsyncIOTransport: BigtableGrpcTransport + AsyncChannel: Channel + +added_imports: + - "from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient" + - "from google.cloud.bigtable_v2.services.bigtable.transports.grpc import BigtableGrpcTransport" + - "from typing import Generator, Iterable, Iterator" + - "from grpc import Channel" classes: # Specify transformations for individual classes - path: google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 50986ac4f..3f0e47f34 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -36,18 +36,18 @@ class AsyncToSyncTransformer(ast.NodeTransformer): outside of this autogeneration system are always applied """ - def __init__(self, *, name=None, module_replacements=None, text_replacements=None, drop_methods=None, pass_methods=None, error_methods=None): + def __init__(self, *, name=None, asyncio_replacements=None, text_replacements=None, drop_methods=None, pass_methods=None, error_methods=None): """ Args: - name: the name of the class being processed. 
Just used in exceptions - - module_replacements: modules to replace + - asyncio_replacements: asyncio functionality to replace - text_replacements: dict of text to replace directly in the source code and docstrings - drop_methods: list of method names to drop from the class - pass_methods: list of method names to replace with "pass" in the class - error_methods: list of method names to replace with "raise NotImplementedError" in the class """ self.name = name - self.module_replacements = module_replacements or {} + self.asyncio_replacements = asyncio_replacements or {} self.text_replacements = text_replacements or {} self.drop_methods = drop_methods or [] self.pass_methods = pass_methods or [] @@ -98,7 +98,7 @@ def visit_AsyncFunctionDef(self, node): and isinstance(n.func, ast.Attribute) \ and isinstance(n.func.value, ast.Name) \ and n.func.value.id == "asyncio" \ - and f"asyncio.{n.func.attr}" not in self.module_replacements: + and n.func.attr not in self.asyncio_replacements: path_str = f"{self.name}.{node.name}" if self.name else node.name raise RuntimeError( f"{path_str} contains unhandled asyncio calls: {n.func.attr}. Add method to drop_methods, pass_methods, or error_methods to handle safely." @@ -140,20 +140,15 @@ def visit_Await(self, node): return self.visit(node.value) def visit_Attribute(self, node): - parts = [] - attr_node = node - while isinstance(attr_node, ast.Attribute): - parts.append(attr_node.attr) - attr_node = attr_node.value - if isinstance(attr_node, ast.Name): - parts.append(attr_node.id) - full_name = ".".join(parts[::-1]) - - if full_name in self.module_replacements: - # replace from module_replacements - replacement = self.module_replacements[full_name] + if ( + isinstance(node.value, ast.Name) + and isinstance(node.value.ctx, ast.Load) + and node.value.id == "asyncio" + and node.attr in self.asyncio_replacements + ): + replacement = self.asyncio_replacements[node.attr] return ast.copy_location(ast.parse(replacement, mode="eval").body, node) - elif node.attr in self.text_replacements: + if node.attr in self.text_replacements: # replace from text_replacements new_node = ast.copy_location( ast.Attribute(self.visit(node.value), self.text_replacements[node.attr], node.ctx), node @@ -266,7 +261,7 @@ def _create_error_node(node, error_msg): def get_imports(self, filename): """ - Get the imports from a file, and do a find-and-replace against module_replacements + Get the imports from a file, and do a find-and-replace against asyncio_replacements """ imports = set() with open(filename, "r") as f: @@ -276,14 +271,14 @@ def get_imports(self, filename): for alias in node.names: if isinstance(node, ast.Import): # import statments - new_import = self.module_replacements.get(alias.name, alias.name) + new_import = self.asyncio_replacements.get(alias.name, alias.name) imports.add(ast.parse(f"import {new_import}").body[0]) else: # import from statements # break into individual components full_path = f"{node.module}.{alias.name}" - if full_path in self.module_replacements: - full_path = self.module_replacements[full_path] + if full_path in self.asyncio_replacements: + full_path = self.asyncio_replacements[full_path] module, name = full_path.rsplit(".", 1) # don't import from same file if module == ".": @@ -320,9 +315,6 @@ def transform_class(in_obj: Type, **kwargs): # find imports imports = transformer.get_imports(filename) imports.add(ast.parse("from abc import ABC").body[0]) - # add globals - # for g in transformer.globals: - # imports.add(ast.parse(f"import {g}").body[0]) # 
add locals from file, in case they are needed if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): with open(filename, "r") as f: @@ -352,7 +344,7 @@ def transform_from_config(config_dict: dict): module_path, class_name = class_dict.pop("path").rsplit(".", 1) class_object = getattr(importlib.import_module(module_path), class_name) # add globals to class_dict - class_dict["module_replacements"] = {**config_dict.get("module_replacements", {}), **class_dict.get("module_replacements", {})} + class_dict["asyncio_replacements"] = {**config_dict.get("asyncio_replacements", {}), **class_dict.get("asyncio_replacements", {})} class_dict["text_replacements"] = {**global_text_replacements, **class_dict.get("text_replacements", {})} # transform class tree_body, imports = transform_class(class_object, **class_dict) @@ -360,9 +352,8 @@ def transform_from_config(config_dict: dict): combined_tree.body.extend(tree_body) combined_imports.update(imports) # add extra imports - # if add_imports: - # for import_str in add_imports: - # combined_imports.add(ast.parse(import_str).body[0]) + for import_str in config_dict.get("added_imports", []): + combined_imports.add(ast.parse(import_str).body[0]) # render tree as string of code import_unique = list(set([ast.unparse(i) for i in combined_imports])) import_unique.sort() From 1af60be9b02e9e0d511c199b38a8a204cc2a491c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Apr 2024 14:39:22 -0700 Subject: [PATCH 024/360] update retries_async --- google/cloud/bigtable/data/_sync/_autogen.py | 8 ++++---- google/cloud/bigtable/data/_sync/sync_gen.yaml | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 6bfa2c164..2b8837e61 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -140,7 +140,7 @@ def __init__( def start_operation(self) -> Generator[Row, None, "Any"]: """Start the read_rows operation, retrying on retryable errors.""" - return retries.retry_target_stream_async( + return retries.retry_target_stream( self._read_rows_attempt, self._predicate, exponential_sleep_generator(0.01, 60, multiplier=2), @@ -380,7 +380,7 @@ def __init__( *retryable_exceptions, bt_exceptions._MutateRowsIncomplete ) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - self._operation = retries.retry_target_async( + self._operation = retries.retry_target( self._run_attempt, self.is_retryable, sleep_generator, @@ -1468,7 +1468,7 @@ def execute_rpc(): ) return [(s.row_key, s.offset_bytes) for s in results] - return retries.retry_target_async( + return retries.retry_target( execute_rpc, predicate, sleep_generator, @@ -1595,7 +1595,7 @@ def mutate_row( metadata=_make_metadata(self.table_name, self.app_profile_id), retry=None, ) - return retries.retry_target_async( + return retries.retry_target( target, predicate, sleep_generator, diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index b48030ccf..82b96df6a 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -18,6 +18,8 @@ text_replacements: # Find and replace specific text patterns BigtableAsyncClient: BigtableClient PooledBigtableGrpcAsyncIOTransport: BigtableGrpcTransport AsyncChannel: Channel + retry_target_async: retry_target + retry_target_stream_async: retry_target_stream added_imports: - "from 
google.cloud.bigtable_v2.services.bigtable.client import BigtableClient" From e486ad57735e2134a8b4a5c762497489dd3c663b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Apr 2024 14:43:06 -0700 Subject: [PATCH 025/360] replaced asyncio.Task --- google/cloud/bigtable/data/_sync/_autogen.py | 3 +-- google/cloud/bigtable/data/_sync/sync_gen.yaml | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 2b8837e61..3e5340ac0 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -26,7 +26,6 @@ from typing import Sequence from typing import Set from typing import cast -import asyncio import atexit import functools import os @@ -927,7 +926,7 @@ def __init__( self._active_instances: Set[_WarmedInstanceKey] = set() self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() - self._channel_refresh_tasks: list[asyncio.Task[None]] = [] + self._channel_refresh_tasks: list[threading.Thread[None]] = [] if self._emulator_host is not None: warnings.warn( "Connecting to Bigtable emulator at {}".format(self._emulator_host), diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index 82b96df6a..0fbd71af1 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -3,6 +3,7 @@ asyncio_replacements: # Replace asyncio functionaility Queue: queue.Queue Condition: threading.Condition Future: concurrent.futures.Future + Task: threading.Thread text_replacements: # Find and replace specific text patterns __anext__: __next__ From 1974a50674edfd00cdd762e428ef20e4d412be90 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Apr 2024 14:48:34 -0700 Subject: [PATCH 026/360] removed use of async generator --- .../cloud/bigtable/data/_async/_read_rows.py | 9 ++++---- google/cloud/bigtable/data/_sync/_autogen.py | 10 ++++---- .../cloud/bigtable/data/_sync/sync_gen.yaml | 1 - sync_surface_generator.py | 23 ------------------- 4 files changed, 9 insertions(+), 34 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 9e0fd78e1..57f150c7d 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -17,7 +17,6 @@ from typing import ( TYPE_CHECKING, - AsyncGenerator, AsyncIterable, Awaitable, Sequence, @@ -101,7 +100,7 @@ def __init__( self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None - def start_operation(self) -> AsyncGenerator[Row, None]: + def start_operation(self) -> AsyncIterable[Row]: """ Start the read_rows operation, retrying on retryable errors. """ @@ -113,7 +112,7 @@ def start_operation(self) -> AsyncGenerator[Row, None]: exception_factory=_retry_exception_factory, ) - def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: + def _read_rows_attempt(self) -> AsyncIterable[Row]: """ Attempt a single read_rows rpc call. 
This function is intended to be wrapped by retry logic, @@ -148,7 +147,7 @@ def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: async def chunk_stream( self, stream: Awaitable[AsyncIterable[ReadRowsResponsePB]] - ) -> AsyncGenerator[ReadRowsResponsePB.CellChunk, None]: + ) -> AsyncIterable[ReadRowsResponsePB.CellChunk]: """ process chunks out of raw read_rows stream """ @@ -194,7 +193,7 @@ async def chunk_stream( @staticmethod async def merge_rows( - chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None + chunks: AsyncIterable[ReadRowsResponsePB.CellChunk] | None ): """ Merge chunks into rows diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 3e5340ac0..234957961 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -21,7 +21,7 @@ from functools import partial from grpc import Channel from typing import Any -from typing import Generator, Iterable +from typing import Iterable from typing import Optional from typing import Sequence from typing import Set @@ -137,7 +137,7 @@ def __init__( self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None - def start_operation(self) -> Generator[Row, None, "Any"]: + def start_operation(self) -> Iterable[Row]: """Start the read_rows operation, retrying on retryable errors.""" return retries.retry_target_stream( self._read_rows_attempt, @@ -147,7 +147,7 @@ def start_operation(self) -> Generator[Row, None, "Any"]: exception_factory=_retry_exception_factory, ) - def _read_rows_attempt(self) -> Generator[Row, None, "Any"]: + def _read_rows_attempt(self) -> Iterable[Row]: """ Attempt a single read_rows rpc call. This function is intended to be wrapped by retry logic, @@ -177,7 +177,7 @@ def _read_rows_attempt(self) -> Generator[Row, None, "Any"]: def chunk_stream( self, stream: None[Iterable[ReadRowsResponsePB]] - ) -> Generator[ReadRowsResponsePB.CellChunk, None, "Any"]: + ) -> Iterable[ReadRowsResponsePB.CellChunk]: """process chunks out of raw read_rows stream""" for resp in stream: resp = resp._pb @@ -211,7 +211,7 @@ def chunk_stream( current_key = None @staticmethod - def merge_rows(chunks: Generator[ReadRowsResponsePB.CellChunk, None, "Any"] | None): + def merge_rows(chunks: Iterable[ReadRowsResponsePB.CellChunk] | None): """Merge chunks into rows""" if chunks is None: return diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index 0fbd71af1..6f3fddaab 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -13,7 +13,6 @@ text_replacements: # Find and replace specific text patterns aclose: close AsyncIterable: Iterable AsyncIterator: Iterator - AsyncGenerator: Generator StopAsyncIteration: StopIteration Awaitable: None BigtableAsyncClient: BigtableClient diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 3f0e47f34..b188d7c9b 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -208,30 +208,7 @@ def visit_ListComp(self, node): ) def visit_Subscript(self, node): - # TODO: generalize? 
if ( - hasattr(node, "value") - and isinstance(node.value, ast.Name) - and node.value.id == "AsyncGenerator" - and self.text_replacements.get(node.value.id, "") == "Generator" - ): - # Generator has different argument signature than AsyncGenerator - return ast.copy_location( - ast.Subscript( - ast.Name("Generator"), - ast.Index( - ast.Tuple( - [ - self.visit(i) - for i in node.slice.elts + [ast.Constant("Any")] - ] - ) - ), - node.ctx, - ), - node, - ) - elif ( hasattr(node, "value") and isinstance(node.value, ast.Name) and self.text_replacements.get(node.value.id, False) is None From 64ec3d586336d8a80e95c73a292f60f999521e3b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Apr 2024 14:54:00 -0700 Subject: [PATCH 027/360] fixed type issue --- google/cloud/bigtable/data/_async/client.py | 6 +++--- google/cloud/bigtable/data/_sync/_autogen.py | 4 ++-- google/cloud/bigtable/data/_sync/client.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index bd9e8c8c0..26e439e43 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -170,7 +170,7 @@ def __init__( RuntimeWarning, stacklevel=2, ) - self._prep_emulator_channel(pool_size) + self._prep_emulator_channel(self._emulator_host, pool_size) # refresh cached stubs to use emulator pool self.transport._stubs = {} self.transport._prep_wrapped_messages(client_info) @@ -197,7 +197,7 @@ def _transport_init(self, pool_size: int) -> str: BigtableClientMeta._transport_registry[transport_str] = transport return transport_str - def _prep_emulator_channel(self, pool_size:int): + def _prep_emulator_channel(self, host:str, pool_size:int): """ Helper function for initializing emulator's insecure grpc channel @@ -205,7 +205,7 @@ def _prep_emulator_channel(self, pool_size:int): """ self.transport._grpc_channel = PooledChannel( pool_size=pool_size, - host=self._emulator_host, + host=host, insecure=True, ) diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 234957961..7ab75b360 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -933,7 +933,7 @@ def __init__( RuntimeWarning, stacklevel=2, ) - self._prep_emulator_channel(pool_size) + self._prep_emulator_channel(self._emulator_host, pool_size) self.transport._stubs = {} self.transport._prep_wrapped_messages(client_info) else: @@ -949,7 +949,7 @@ def __init__( def _transport_init(self, pool_size: int) -> str: raise NotImplementedError("Function not implemented in sync class") - def _prep_emulator_channel(self, pool_size: int): + def _prep_emulator_channel(self, host: str, pool_size: int): raise NotImplementedError("Function not implemented in sync class") @staticmethod diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index d95f82107..0bb6b9586 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -42,8 +42,8 @@ def __init__( def _transport_init(self, pool_size: int) -> str: return "grpc" - def _prep_emulator_channel(self, pool_size: int) -> str: - self.transport._grpc_channel = grpc.insecure_channel(target=self._emulator_host) + def _prep_emulator_channel(self, host:str, pool_size: int) -> str: + self.transport._grpc_channel = grpc.insecure_channel(target=host) @staticmethod def _client_version() -> str: From 
f67ebc476993536daf9868f562566de82bfcae84 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Apr 2024 15:43:12 -0700 Subject: [PATCH 028/360] got _mutate_rows tests passing --- .../bigtable/data/_async/_mutate_rows.py | 6 +- google/cloud/bigtable/data/_sync/_autogen.py | 5 +- .../cloud/bigtable/data/_sync/sync_gen.yaml | 3 +- sync_surface_generator.py | 26 +- tests/unit/data/_async/test__mutate_rows.py | 47 ++- tests/unit/data/_sync/test_autogen.py | 324 ++++++++++++++++++ 6 files changed, 373 insertions(+), 38 deletions(-) create mode 100644 tests/unit/data/_sync/test_autogen.py diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 7d1144553..aed14d338 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -101,7 +101,9 @@ def __init__( bt_exceptions._MutateRowsIncomplete, ) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - self._operation = retries.retry_target_async( + # Note: _operation could be a raw coroutine, but using a lambda + # wrapper helps unify with sync code + self._operation = lambda: retries.retry_target_async( self._run_attempt, self.is_retryable, sleep_generator, @@ -125,7 +127,7 @@ async def start(self): """ try: # trigger mutate_rows - await self._operation + await self._operation() except Exception as exc: # exceptions raised by retryable are added to the list of exceptions for all unfinalized mutations incomplete_indices = self.remaining_indices.copy() diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 7ab75b360..795c5d886 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -84,6 +84,7 @@ import google.auth._default import google.auth.credentials import google.cloud.bigtable.data.exceptions +import google.cloud.bigtable.data.exceptions as bt_exceptions import google.cloud.bigtable_v2.types.bigtable @@ -379,7 +380,7 @@ def __init__( *retryable_exceptions, bt_exceptions._MutateRowsIncomplete ) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - self._operation = retries.retry_target( + self._operation = lambda: retries.retry_target( self._run_attempt, self.is_retryable, sleep_generator, @@ -401,7 +402,7 @@ def start(self): - MutationsExceptionGroup: if any mutations failed """ try: - self._operation + self._operation() except Exception as exc: incomplete_indices = self.remaining_indices.copy() for idx in incomplete_indices: diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index 6f3fddaab..4915058ea 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -26,6 +26,7 @@ added_imports: - "from google.cloud.bigtable_v2.services.bigtable.transports.grpc import BigtableGrpcTransport" - "from typing import Generator, Iterable, Iterator" - "from grpc import Channel" + - "import google.cloud.bigtable.data.exceptions as bt_exceptions" classes: # Specify transformations for individual classes - path: google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync @@ -38,7 +39,7 @@ classes: # Specify transformations for individual classes autogen_sync_name: MutationsBatcher_SyncGen concrete_path: google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher pass_methods: ["close", "_wait_for_batch_results"] - error_methods: ["_create_bg_task", 
"_start_flush_timer"] + error_methods: ["_create_bg_task", "_start_flush_timer"] - path: google.cloud.bigtable.data._async.mutations_batcher._FlowControlAsync autogen_sync_name: _FlowControl_SyncGen concrete_path: google.cloud.bigtable.data._sync.mutations_batcher._FlowControl diff --git a/sync_surface_generator.py b/sync_surface_generator.py index b188d7c9b..41f248662 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -36,7 +36,7 @@ class AsyncToSyncTransformer(ast.NodeTransformer): outside of this autogeneration system are always applied """ - def __init__(self, *, name=None, asyncio_replacements=None, text_replacements=None, drop_methods=None, pass_methods=None, error_methods=None): + def __init__(self, *, name=None, asyncio_replacements=None, text_replacements=None, drop_methods=None, pass_methods=None, error_methods=None, replace_methods=None): """ Args: - name: the name of the class being processed. Just used in exceptions @@ -45,6 +45,7 @@ def __init__(self, *, name=None, asyncio_replacements=None, text_replacements=No - drop_methods: list of method names to drop from the class - pass_methods: list of method names to replace with "pass" in the class - error_methods: list of method names to replace with "raise NotImplementedError" in the class + - replace_methods: dict of method names to replace with custom code """ self.name = name self.asyncio_replacements = asyncio_replacements or {} @@ -52,6 +53,7 @@ def __init__(self, *, name=None, asyncio_replacements=None, text_replacements=No self.drop_methods = drop_methods or [] self.pass_methods = pass_methods or [] self.error_methods = error_methods or [] + self.replace_methods = replace_methods or {} def update_docstring(self, docstring): """ @@ -90,6 +92,14 @@ def visit_AsyncFunctionDef(self, node): node.body = [ast.Expr(value=ast.Str(s=docstring))] elif node.name in self.error_methods: self._create_error_node(node, "Function not implemented in sync class") + elif node.name in self.replace_methods: + # replace function body with custom code + new_body = [] + for line in self.replace_methods[node.name].split("\n"): + parsed = ast.parse(line) + if len(parsed.body) > 0: + new_body.append(parsed.body[0]) + node.body = new_body else: # check if the function contains non-replaced usage of asyncio func_ast = ast.parse(ast.unparse(node)) @@ -365,14 +375,14 @@ def transform_from_config(config_dict: dict): if __name__ == "__main__": - load_path = "./google/cloud/bigtable/data/_sync/sync_gen.yaml" - config = yaml.safe_load(Path(load_path).read_text()) + for load_path in ["./google/cloud/bigtable/data/_sync/sync_gen.yaml", "./google/cloud/bigtable/data/_sync/unit_tests.yaml"]: + config = yaml.safe_load(Path(load_path).read_text()) - save_path = config.get("save_path") - code = transform_from_config(config) + save_path = config.get("save_path") + code = transform_from_config(config) - if save_path is not None: - with open(save_path, "w") as f: - f.write(code) + if save_path is not None: + with open(save_path, "w") as f: + f.write(code) diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index e03028c45..81151f9b6 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -27,12 +27,6 @@ from mock import AsyncMock # type: ignore -def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation - class TestMutateRowsOperation: def 
_target_class(self): @@ -52,6 +46,12 @@ def _make_one(self, *args, **kwargs): kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) return self._target_class()(*args, **kwargs) + def _make_mutation(self, count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + async def _mock_stream(self, mutation_list, error_dict): for idx, entry in enumerate(mutation_list): code = error_dict.get(idx, 0) @@ -83,7 +83,7 @@ def test_ctor(self): client = mock.Mock() table = mock.Mock() - entries = [_make_mutation(), _make_mutation()] + entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 attempt_timeout = 0.01 retryable_exceptions = () @@ -136,7 +136,7 @@ def test_ctor_too_many_entries(self): client = mock.Mock() table = mock.Mock() - entries = [_make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT + entries = [self._make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT operation_timeout = 0.05 attempt_timeout = 0.01 # no errors if at limit @@ -146,7 +146,7 @@ def test_ctor_too_many_entries(self): self._make_one( client, table, - entries + [_make_mutation()], + entries + [self._make_mutation()], operation_timeout, attempt_timeout, ) @@ -162,7 +162,7 @@ async def test_mutate_rows_operation(self): """ client = mock.Mock() table = mock.Mock() - entries = [_make_mutation(), _make_mutation()] + entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 cls = self._target_class() with mock.patch( @@ -184,7 +184,7 @@ async def test_mutate_rows_attempt_exception(self, exc_type): """ client = AsyncMock() table = mock.Mock() - entries = [_make_mutation(), _make_mutation()] + entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 expected_exception = exc_type("test") client.mutate_rows.side_effect = expected_exception @@ -215,7 +215,7 @@ async def test_mutate_rows_exception(self, exc_type): client = mock.Mock() table = mock.Mock() - entries = [_make_mutation(), _make_mutation()] + entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 expected_cause = exc_type("abort") with mock.patch.object( @@ -248,18 +248,15 @@ async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): """ If an exception fails but eventually passes, it should not raise an exception """ - from google.cloud.bigtable.data._async._mutate_rows import ( - _MutateRowsOperationAsync, - ) client = mock.Mock() table = mock.Mock() - entries = [_make_mutation()] + entries = [self._make_mutation()] operation_timeout = 1 expected_cause = exc_type("retry") num_retries = 2 with mock.patch.object( - _MutateRowsOperationAsync, + self._target_class(), "_run_attempt", AsyncMock(), ) as attempt_mock: @@ -286,7 +283,7 @@ async def test_mutate_rows_incomplete_ignored(self): client = mock.Mock() table = mock.Mock() - entries = [_make_mutation()] + entries = [self._make_mutation()] operation_timeout = 0.05 with mock.patch.object( self._target_class(), @@ -309,7 +306,7 @@ async def test_mutate_rows_incomplete_ignored(self): @pytest.mark.asyncio async def test_run_attempt_single_entry_success(self): """Test mutating a single entry""" - mutation = _make_mutation() + mutation = self._make_mutation() expected_timeout = 1.3 mock_gapic_fn = self._make_mock_gapic({0: mutation}) instance = self._make_one( @@ -339,9 +336,9 @@ async def test_run_attempt_partial_success_retryable(self): """Some entries succeed, but one fails. 
Should report the proper index, and raise incomplete exception""" from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - success_mutation = _make_mutation() - success_mutation_2 = _make_mutation() - failure_mutation = _make_mutation() + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() mutations = [success_mutation, failure_mutation, success_mutation_2] mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) instance = self._make_one( @@ -360,9 +357,9 @@ async def test_run_attempt_partial_success_retryable(self): @pytest.mark.asyncio async def test_run_attempt_partial_success_non_retryable(self): """Some entries succeed, but one fails. Exception marked as non-retryable. Do not raise incomplete error""" - success_mutation = _make_mutation() - success_mutation_2 = _make_mutation() - failure_mutation = _make_mutation() + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() mutations = [success_mutation, failure_mutation, success_mutation_2] mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) instance = self._make_one( diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py new file mode 100644 index 000000000..3083da925 --- /dev/null +++ b/tests/unit/data/_sync/test_autogen.py @@ -0,0 +1,324 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by sync_surface_generator.py. Do not edit. 
+ + +from abc import ABC +from unittest import mock +import mock +import pytest + +from google.cloud.bigtable_v2.types import MutateRowsResponse +from google.rpc import status_pb2 +import google.api_core.exceptions as core_exceptions + + +class TestMutateRowsOperation_SyncGen(ABC): + def _target_class(self): + from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation + + return _MutateRowsOperation + + def _make_one(self, *args, **kwargs): + if not args: + kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) + kwargs["table"] = kwargs.pop("table", mock.Mock()) + kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) + kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) + kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) + kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) + return self._target_class()(*args, **kwargs) + + def _make_mutation(self, count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def _mock_stream(self, mutation_list, error_dict): + for idx, entry in enumerate(mutation_list): + code = error_dict.get(idx, 0) + yield MutateRowsResponse( + entries=[ + MutateRowsResponse.Entry( + index=idx, status=status_pb2.Status(code=code) + ) + ] + ) + + def _make_mock_gapic(self, mutation_list, error_dict=None): + mock_fn = mock.Mock() + if error_dict is None: + error_dict = {} + mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( + mutation_list, error_dict + ) + return mock_fn + + def test_ctor(self): + """test that constructor sets all the attributes correctly""" + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + from google.api_core.exceptions import DeadlineExceeded + from google.api_core.exceptions import Aborted + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + attempt_timeout = 0.01 + retryable_exceptions = () + instance = self._make_one( + client, + table, + entries, + operation_timeout, + attempt_timeout, + retryable_exceptions, + ) + assert client.mutate_rows.call_count == 0 + instance._gapic_fn() + assert client.mutate_rows.call_count == 1 + inner_kwargs = client.mutate_rows.call_args[1] + assert len(inner_kwargs) == 4 + assert inner_kwargs["table_name"] == table.table_name + assert inner_kwargs["app_profile_id"] == table.app_profile_id + assert inner_kwargs["retry"] is None + metadata = inner_kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert str(table.table_name) in metadata[0][1] + assert str(table.app_profile_id) in metadata[0][1] + entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries] + assert instance.mutations == entries_w_pb + assert next(instance.timeout_generator) == attempt_timeout + assert instance.is_retryable is not None + assert instance.is_retryable(DeadlineExceeded("")) is False + assert instance.is_retryable(Aborted("")) is False + assert instance.is_retryable(_MutateRowsIncomplete("")) is True + assert instance.is_retryable(RuntimeError("")) is False + assert instance.remaining_indices == list(range(len(entries))) + assert instance.errors == {} + + def test_ctor_too_many_entries(self): + """should raise an error if an operation is created with more than 100,000 entries""" + from 
google.cloud.bigtable.data._async._mutate_rows import ( + _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, + ) + + assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100000 + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT + operation_timeout = 0.05 + attempt_timeout = 0.01 + self._make_one(client, table, entries, operation_timeout, attempt_timeout) + with pytest.raises(ValueError) as e: + self._make_one( + client, + table, + entries + [self._make_mutation()], + operation_timeout, + attempt_timeout, + ) + assert "mutate_rows requests can contain at most 100000 mutations" in str( + e.value + ) + assert "Found 100001" in str(e.value) + + def test_mutate_rows_operation(self): + """Test successful case of mutate_rows_operation""" + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + cls = self._target_class() + with mock.patch( + f"{cls.__module__}.{cls.__name__}._run_attempt", mock.Mock() + ) as attempt_mock: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + assert attempt_mock.call_count == 1 + + @pytest.mark.parametrize( + "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] + ) + def test_mutate_rows_attempt_exception(self, exc_type): + """exceptions raised from attempt should be raised in MutationsExceptionGroup""" + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + expected_exception = exc_type("test") + client.mutate_rows.side_effect = expected_exception + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance._run_attempt() + except Exception as e: + found_exc = e + assert client.mutate_rows.call_count == 1 + assert type(found_exc) is exc_type + assert found_exc == expected_exception + assert len(instance.errors) == 2 + assert len(instance.remaining_indices) == 0 + + @pytest.mark.parametrize( + "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] + ) + def test_mutate_rows_exception(self, exc_type): + """exceptions raised from retryable should be raised in MutationsExceptionGroup""" + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + expected_cause = exc_type("abort") + with mock.patch.object( + self._target_class(), "_run_attempt", mock.Mock() + ) as attempt_mock: + attempt_mock.side_effect = expected_cause + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + except MutationsExceptionGroup as e: + found_exc = e + assert attempt_mock.call_count == 1 + assert len(found_exc.exceptions) == 2 + assert isinstance(found_exc.exceptions[0], FailedMutationEntryError) + assert isinstance(found_exc.exceptions[1], FailedMutationEntryError) + assert found_exc.exceptions[0].__cause__ == expected_cause + assert found_exc.exceptions[1].__cause__ == expected_cause + + @pytest.mark.parametrize( + "exc_type", [core_exceptions.DeadlineExceeded, RuntimeError] + ) + def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): + """If an exception fails but eventually passes, it should 
not raise an exception""" + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] + operation_timeout = 1 + expected_cause = exc_type("retry") + num_retries = 2 + with mock.patch.object( + self._target_class(), "_run_attempt", mock.Mock() + ) as attempt_mock: + attempt_mock.side_effect = [expected_cause] * num_retries + [None] + instance = self._make_one( + client, + table, + entries, + operation_timeout, + operation_timeout, + retryable_exceptions=(exc_type,), + ) + instance.start() + assert attempt_mock.call_count == num_retries + 1 + + def test_mutate_rows_incomplete_ignored(self): + """MutateRowsIncomplete exceptions should not be added to error list""" + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.api_core.exceptions import DeadlineExceeded + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] + operation_timeout = 0.05 + with mock.patch.object( + self._target_class(), "_run_attempt", mock.Mock() + ) as attempt_mock: + attempt_mock.side_effect = _MutateRowsIncomplete("ignored") + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + except MutationsExceptionGroup as e: + found_exc = e + assert attempt_mock.call_count > 0 + assert len(found_exc.exceptions) == 1 + assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) + + def test_run_attempt_single_entry_success(self): + """Test mutating a single entry""" + mutation = self._make_mutation() + expected_timeout = 1.3 + mock_gapic_fn = self._make_mock_gapic({0: mutation}) + instance = self._make_one( + mutation_entries=[mutation], attempt_timeout=expected_timeout + ) + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + instance._run_attempt() + assert len(instance.remaining_indices) == 0 + assert mock_gapic_fn.call_count == 1 + (_, kwargs) = mock_gapic_fn.call_args + assert kwargs["timeout"] == expected_timeout + assert kwargs["entries"] == [mutation._to_pb()] + + def test_run_attempt_empty_request(self): + """Calling with no mutations should result in no API calls""" + mock_gapic_fn = self._make_mock_gapic([]) + instance = self._make_one(mutation_entries=[]) + instance._run_attempt() + assert mock_gapic_fn.call_count == 0 + + def test_run_attempt_partial_success_retryable(self): + """Some entries succeed, but one fails. Should report the proper index, and raise incomplete exception""" + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() + mutations = [success_mutation, failure_mutation, success_mutation_2] + mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) + instance = self._make_one(mutation_entries=mutations) + instance.is_retryable = lambda x: True + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + with pytest.raises(_MutateRowsIncomplete): + instance._run_attempt() + assert instance.remaining_indices == [1] + assert 0 not in instance.errors + assert len(instance.errors[1]) == 1 + assert instance.errors[1][0].grpc_status_code == 300 + assert 2 not in instance.errors + + def test_run_attempt_partial_success_non_retryable(self): + """Some entries succeed, but one fails. Exception marked as non-retryable. 
Do not raise incomplete error""" + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() + mutations = [success_mutation, failure_mutation, success_mutation_2] + mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) + instance = self._make_one(mutation_entries=mutations) + instance.is_retryable = lambda x: False + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + instance._run_attempt() + assert instance.remaining_indices == [] + assert 0 not in instance.errors + assert len(instance.errors[1]) == 1 + assert instance.errors[1][0].grpc_status_code == 300 + assert 2 not in instance.errors From 2ebbb7270c0db53a532a9e2b4fe330bf216887e4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Apr 2024 15:58:46 -0700 Subject: [PATCH 029/360] got sync read_rows tests passing --- .../cloud/bigtable/data/_async/_read_rows.py | 10 +- google/cloud/bigtable/data/_sync/_autogen.py | 7 +- tests/unit/data/_async/test__read_rows.py | 36 +- tests/unit/data/_sync/test_autogen.py | 324 +++++++++++++++++- 4 files changed, 334 insertions(+), 43 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 57f150c7d..a6fe67847 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -31,9 +31,7 @@ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _RowSetComplete -from google.cloud.bigtable.data._helpers import _attempt_timeout_generator -from google.cloud.bigtable.data._helpers import _make_metadata -from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data import _helpers from google.api_core import retry as retries from google.api_core.retry import exponential_sleep_generator @@ -79,7 +77,7 @@ def __init__( attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), ): - self.attempt_timeout_gen = _attempt_timeout_generator( + self.attempt_timeout_gen = _helpers._attempt_timeout_generator( attempt_timeout, operation_timeout ) self.operation_timeout = operation_timeout @@ -93,7 +91,7 @@ def __init__( self.request = query._to_pb(table) self.table = table self._predicate = retries.if_exception_type(*retryable_exceptions) - self._metadata = _make_metadata( + self._metadata = _helpers._make_metadata( table.table_name, table.app_profile_id, ) @@ -109,7 +107,7 @@ def start_operation(self) -> AsyncIterable[Row]: self._predicate, exponential_sleep_generator(0.01, 60, multiplier=2), self.operation_timeout, - exception_factory=_retry_exception_factory, + exception_factory=_helpers._retry_exception_factory, ) def _read_rows_attempt(self) -> AsyncIterable[Row]: diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 795c5d886..883a1afac 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -40,6 +40,7 @@ from google.api_core.exceptions import ServiceUnavailable from google.api_core.retry import exponential_sleep_generator from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT +from google.cloud.bigtable.data import _helpers from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto from google.cloud.bigtable.data._async._mutate_rows import ( 
_MUTATE_ROWS_REQUEST_MUTATION_LIMIT, @@ -120,7 +121,7 @@ def __init__( attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), ): - self.attempt_timeout_gen = _attempt_timeout_generator( + self.attempt_timeout_gen = _helpers._attempt_timeout_generator( attempt_timeout, operation_timeout ) self.operation_timeout = operation_timeout @@ -134,7 +135,7 @@ def __init__( self.request = query._to_pb(table) self.table = table self._predicate = retries.if_exception_type(*retryable_exceptions) - self._metadata = _make_metadata(table.table_name, table.app_profile_id) + self._metadata = _helpers._make_metadata(table.table_name, table.app_profile_id) self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None @@ -145,7 +146,7 @@ def start_operation(self) -> Iterable[Row]: self._predicate, exponential_sleep_generator(0.01, 60, multiplier=2), self.operation_timeout, - exception_factory=_retry_exception_factory, + exception_factory=_helpers._retry_exception_factory, ) def _read_rows_attempt(self) -> Iterable[Row]: diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 4e7797c6d..fab338fdf 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -61,7 +61,7 @@ def test_ctor(self): expected_request_timeout = 44 time_gen_mock = mock.Mock() with mock.patch( - "google.cloud.bigtable.data._async._read_rows._attempt_timeout_generator", + "google.cloud.bigtable.data._helpers._attempt_timeout_generator", time_gen_mock, ): instance = self._make_one( @@ -308,7 +308,7 @@ async def mock_stream(): yield 1 with mock.patch.object( - _ReadRowsOperationAsync, "_read_rows_attempt" + self._get_target_class(), "_read_rows_attempt" ) as mock_attempt: instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) wrapped_gen = mock_stream() @@ -330,7 +330,6 @@ async def test_retryable_ignore_repeated_rows(self): """ Duplicate rows should cause an invalid chunk error """ - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable_v2.types import ReadRowsResponse @@ -355,37 +354,8 @@ async def mock_stream(): instance = mock.Mock() instance._last_yielded_row_key = None instance._remaining_count = None - stream = _ReadRowsOperationAsync.chunk_stream(instance, mock_awaitable_stream()) + stream = self._get_target_class().chunk_stream(instance, mock_awaitable_stream()) await stream.__anext__() with pytest.raises(InvalidChunk) as exc: await stream.__anext__() assert "row keys should be strictly increasing" in str(exc.value) - - -class MockStream(_ReadRowsOperationAsync): - """ - Mock a _ReadRowsOperationAsync stream for testing - """ - - def __init__(self, items=None, errors=None, operation_timeout=None): - self.transient_errors = errors - self.operation_timeout = operation_timeout - self.next_idx = 0 - if items is None: - items = list(range(10)) - self.items = items - - def __aiter__(self): - return self - - async def __anext__(self): - if self.next_idx >= len(self.items): - raise StopAsyncIteration - item = self.items[self.next_idx] - self.next_idx += 1 - if isinstance(item, Exception): - raise item - return item - - async def aclose(self): - pass diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py index 3083da925..e7fe2ecba 100644 --- a/tests/unit/data/_sync/test_autogen.py +++ 
b/tests/unit/data/_sync/test_autogen.py @@ -16,6 +16,8 @@ from abc import ABC +from tests.unit.data._async.test__mutate_rows import TestMutateRowsOperation +from tests.unit.data._async.test__read_rows import TestReadRowsOperation from unittest import mock import mock import pytest @@ -25,7 +27,7 @@ import google.api_core.exceptions as core_exceptions -class TestMutateRowsOperation_SyncGen(ABC): +class TestMutateRowsOperation(ABC): def _target_class(self): from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation @@ -322,3 +324,323 @@ def test_run_attempt_partial_success_non_retryable(self): assert len(instance.errors[1]) == 1 assert instance.errors[1][0].grpc_status_code == 300 assert 2 not in instance.errors + + +class TestReadRowsOperation(ABC): + """ + Tests helper functions in the ReadRowsOperation class + in-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt + is tested in test_read_rows_acceptance test_client_read_rows, and conformance tests + """ + + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + + return _ReadRowsOperation + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_ctor(self): + from google.cloud.bigtable.data import ReadRowsQuery + + row_limit = 91 + query = ReadRowsQuery(limit=row_limit) + client = mock.Mock() + client.read_rows = mock.Mock() + client.read_rows.return_value = None + table = mock.Mock() + table._client = client + table.table_name = "test_table" + table.app_profile_id = "test_profile" + expected_operation_timeout = 42 + expected_request_timeout = 44 + time_gen_mock = mock.Mock() + with mock.patch( + "google.cloud.bigtable.data._helpers._attempt_timeout_generator", + time_gen_mock, + ): + instance = self._make_one( + query, + table, + operation_timeout=expected_operation_timeout, + attempt_timeout=expected_request_timeout, + ) + assert time_gen_mock.call_count == 1 + time_gen_mock.assert_called_once_with( + expected_request_timeout, expected_operation_timeout + ) + assert instance._last_yielded_row_key is None + assert instance._remaining_count == row_limit + assert instance.operation_timeout == expected_operation_timeout + assert client.read_rows.call_count == 0 + assert instance._metadata == [ + ( + "x-goog-request-params", + "table_name=test_table&app_profile_id=test_profile", + ) + ] + assert instance.request.table_name == table.table_name + assert instance.request.app_profile_id == table.app_profile_id + assert instance.request.rows_limit == row_limit + + @pytest.mark.parametrize( + "in_keys,last_key,expected", + [ + (["b", "c", "d"], "a", ["b", "c", "d"]), + (["a", "b", "c"], "b", ["c"]), + (["a", "b", "c"], "c", []), + (["a", "b", "c"], "d", []), + (["d", "c", "b", "a"], "b", ["d", "c"]), + ], + ) + def test_revise_request_rowset_keys(self, in_keys, last_key, expected): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + in_keys = [key.encode("utf-8") for key in in_keys] + expected = [key.encode("utf-8") for key in expected] + last_key = last_key.encode("utf-8") + sample_range = RowRangePB(start_key_open=last_key) + row_set = RowSetPB(row_keys=in_keys, row_ranges=[sample_range]) + revised = self._get_target_class()._revise_request_rowset(row_set, last_key) + assert revised.row_keys == expected + assert revised.row_ranges == [sample_range] + + @pytest.mark.parametrize( + 
"in_ranges,last_key,expected", + [ + ( + [{"start_key_open": "b", "end_key_closed": "d"}], + "a", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_closed": "b", "end_key_closed": "d"}], + "a", + [{"start_key_closed": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_open": "a", "end_key_closed": "d"}], + "b", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_closed": "a", "end_key_open": "d"}], + "b", + [{"start_key_open": "b", "end_key_open": "d"}], + ), + ( + [{"start_key_closed": "b", "end_key_closed": "d"}], + "b", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ([{"start_key_closed": "b", "end_key_closed": "d"}], "d", []), + ([{"start_key_closed": "b", "end_key_open": "d"}], "d", []), + ([{"start_key_closed": "b", "end_key_closed": "d"}], "e", []), + ([{"start_key_closed": "b"}], "z", [{"start_key_open": "z"}]), + ([{"start_key_closed": "b"}], "a", [{"start_key_closed": "b"}]), + ( + [{"end_key_closed": "z"}], + "a", + [{"start_key_open": "a", "end_key_closed": "z"}], + ), + ( + [{"end_key_open": "z"}], + "a", + [{"start_key_open": "a", "end_key_open": "z"}], + ), + ], + ) + def test_revise_request_rowset_ranges(self, in_ranges, last_key, expected): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + next_key = (last_key + "a").encode("utf-8") + last_key = last_key.encode("utf-8") + in_ranges = [ + RowRangePB(**{k: v.encode("utf-8") for (k, v) in r.items()}) + for r in in_ranges + ] + expected = [ + RowRangePB(**{k: v.encode("utf-8") for (k, v) in r.items()}) + for r in expected + ] + row_set = RowSetPB(row_ranges=in_ranges, row_keys=[next_key]) + revised = self._get_target_class()._revise_request_rowset(row_set, last_key) + assert revised.row_keys == [next_key] + assert revised.row_ranges == expected + + @pytest.mark.parametrize("last_key", ["a", "b", "c"]) + def test_revise_request_full_table(self, last_key): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + last_key = last_key.encode("utf-8") + row_set = RowSetPB() + for selected_set in [row_set, None]: + revised = self._get_target_class()._revise_request_rowset( + selected_set, last_key + ) + assert revised.row_keys == [] + assert len(revised.row_ranges) == 1 + assert revised.row_ranges[0] == RowRangePB(start_key_open=last_key) + + def test_revise_to_empty_rowset(self): + """revising to an empty rowset should raise error""" + from google.cloud.bigtable.data.exceptions import _RowSetComplete + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + row_keys = [b"a", b"b", b"c"] + row_range = RowRangePB(end_key_open=b"c") + row_set = RowSetPB(row_keys=row_keys, row_ranges=[row_range]) + with pytest.raises(_RowSetComplete): + self._get_target_class()._revise_request_rowset(row_set, b"d") + + @pytest.mark.parametrize( + "start_limit,emit_num,expected_limit", + [ + (10, 0, 10), + (10, 1, 9), + (10, 10, 0), + (None, 10, None), + (None, 0, None), + (4, 2, 2), + ], + ) + def test_revise_limit(self, start_limit, emit_num, expected_limit): + """ + revise_limit should revise the request's limit field + - if limit is 0 (unlimited), it should never be revised + - if start_limit-emit_num == 0, the request should end early + - if the number emitted exceeds the new limit, an exception should + should be raised (tested in 
test_revise_limit_over_limit) + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable_v2.types import ReadRowsResponse + + def awaitable_stream(): + def mock_stream(): + for i in range(emit_num): + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + ] + ) + + return mock_stream() + + query = ReadRowsQuery(limit=start_limit) + table = mock.Mock() + table.table_name = "table_name" + table.app_profile_id = "app_profile_id" + instance = self._make_one(query, table, 10, 10) + assert instance._remaining_count == start_limit + for val in instance.chunk_stream(awaitable_stream()): + pass + assert instance._remaining_count == expected_limit + + @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) + def test_revise_limit_over_limit(self, start_limit, emit_num): + """ + Should raise runtime error if we get in state where emit_num > start_num + (unless start_num == 0, which represents unlimited) + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable_v2.types import ReadRowsResponse + from google.cloud.bigtable.data.exceptions import InvalidChunk + + def awaitable_stream(): + def mock_stream(): + for i in range(emit_num): + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + ] + ) + + return mock_stream() + + query = ReadRowsQuery(limit=start_limit) + table = mock.Mock() + table.table_name = "table_name" + table.app_profile_id = "app_profile_id" + instance = self._make_one(query, table, 10, 10) + assert instance._remaining_count == start_limit + with pytest.raises(InvalidChunk) as e: + for val in instance.chunk_stream(awaitable_stream()): + pass + assert "emit count exceeds row limit" in str(e.value) + + def test_close(self): + """ + should be able to close a stream safely with aclose. 
+ Closed generators should raise StopIteration on next yield + """ + + def mock_stream(): + while True: + yield 1 + + with mock.patch.object( + self._get_target_class(), "_read_rows_attempt" + ) as mock_attempt: + instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) + wrapped_gen = mock_stream() + mock_attempt.return_value = wrapped_gen + gen = instance.start_operation() + gen.__next__() + gen.close() + with pytest.raises(StopIteration): + gen.__next__() + gen.close() + with pytest.raises(StopIteration): + wrapped_gen.__next__() + + def test_retryable_ignore_repeated_rows(self): + """Duplicate rows should cause an invalid chunk error""" + from google.cloud.bigtable.data.exceptions import InvalidChunk + from google.cloud.bigtable_v2.types import ReadRowsResponse + + row_key = b"duplicate" + + def mock_awaitable_stream(): + def mock_stream(): + while True: + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) + ] + ) + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) + ] + ) + + return mock_stream() + + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + stream = self._get_target_class().chunk_stream( + instance, mock_awaitable_stream() + ) + stream.__next__() + with pytest.raises(InvalidChunk) as exc: + stream.__next__() + assert "row keys should be strictly increasing" in str(exc.value) From 0bfb257ee3796b95db8965859d3613eb8b84fa01 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Apr 2024 17:18:29 -0700 Subject: [PATCH 030/360] got flow control tests passing --- .../bigtable/data/_async/mutations_batcher.py | 4 +- google/cloud/bigtable/data/_sync/_autogen.py | 4 +- .../cloud/bigtable/data/_sync/sync_gen.yaml | 1 + .../data/_async/test_mutations_batcher.py | 108 +++++--- tests/unit/data/_sync/test_autogen.py | 241 ++++++++++++++++++ 5 files changed, 312 insertions(+), 46 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 5d5dd535e..8195a1929 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -20,7 +20,6 @@ import warnings from collections import deque -from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data._helpers import _get_retryable_errors @@ -28,13 +27,14 @@ from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync -from google.cloud.bigtable.data._async._mutate_rows import ( +from google.cloud.bigtable.data.mutations import ( _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, ) from google.cloud.bigtable.data.mutations import Mutation if TYPE_CHECKING: from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data.mutations import RowMutationEntry # used to make more readable default values _MB_SIZE = 1024 * 1024 diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 883a1afac..b616df2c8 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -29,6 +29,7 @@ import atexit import functools import os +import threading import time import warnings @@ -42,9 
+43,6 @@ from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT from google.cloud.bigtable.data import _helpers from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto -from google.cloud.bigtable.data._async._mutate_rows import ( - _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, -) from google.cloud.bigtable.data._async._read_rows import _ResetRow from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data._helpers import RowKeySamples diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index 4915058ea..1909a3dcf 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -27,6 +27,7 @@ added_imports: - "from typing import Generator, Iterable, Iterator" - "from grpc import Channel" - "import google.cloud.bigtable.data.exceptions as bt_exceptions" + - "import threading" classes: # Specify transformations for individual classes - path: google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index cca7c9824..b790bcfb8 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -27,20 +27,24 @@ from mock import AsyncMock # type: ignore -def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation - - class Test_FlowControl: - def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): + + @staticmethod + def _target_class(): from google.cloud.bigtable.data._async.mutations_batcher import ( _FlowControlAsync, ) + return _FlowControlAsync - return _FlowControlAsync(max_mutation_count, max_mutation_bytes) + def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): + return self._target_class()(max_mutation_count, max_mutation_bytes) + + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation def test_ctor(self): max_mutation_count = 9 @@ -138,7 +142,7 @@ async def test_remove_from_flow_value_update( instance = self._make_one() instance._in_flight_mutation_count = existing_count instance._in_flight_mutation_bytes = existing_size - mutation = _make_mutation(added_count, added_size) + mutation = self._make_mutation(added_count, added_size) await instance.remove_from_flow(mutation) assert instance._in_flight_mutation_count == new_count assert instance._in_flight_mutation_bytes == new_size @@ -146,6 +150,7 @@ async def test_remove_from_flow_value_update( @pytest.mark.asyncio async def test__remove_from_flow_unlock(self): """capacity condition should notify after mutation is complete""" + import inspect instance = self._make_one(10, 10) instance._in_flight_mutation_count = 10 instance._in_flight_mutation_bytes = 10 @@ -155,35 +160,44 @@ async def task_routine(): await instance._capacity_condition.wait_for( lambda: instance._has_capacity(1, 1) ) - - task = asyncio.create_task(task_routine()) + if inspect.iscoroutinefunction(task_routine): + # for async class, build task to test flow unlock + task = asyncio.create_task(task_routine()) + task_alive = lambda: not task.done() + else: + # this branch will be tested in sync version of this test + import threading + thread = 
threading.Thread(target=task_routine) + thread.start() + task_alive = thread.is_alive await asyncio.sleep(0.05) # should be blocked due to capacity - assert task.done() is False + assert task_alive() is True # try changing size - mutation = _make_mutation(count=0, size=5) + mutation = self._make_mutation(count=0, size=5) + await instance.remove_from_flow([mutation]) await asyncio.sleep(0.05) assert instance._in_flight_mutation_count == 10 assert instance._in_flight_mutation_bytes == 5 - assert task.done() is False + assert task_alive() is True # try changing count instance._in_flight_mutation_bytes = 10 - mutation = _make_mutation(count=5, size=0) + mutation = self._make_mutation(count=5, size=0) await instance.remove_from_flow([mutation]) await asyncio.sleep(0.05) assert instance._in_flight_mutation_count == 5 assert instance._in_flight_mutation_bytes == 10 - assert task.done() is False + assert task_alive() is True # try changing both instance._in_flight_mutation_count = 10 - mutation = _make_mutation(count=5, size=5) + mutation = self._make_mutation(count=5, size=5) await instance.remove_from_flow([mutation]) await asyncio.sleep(0.05) assert instance._in_flight_mutation_count == 5 assert instance._in_flight_mutation_bytes == 5 # task should be complete - assert task.done() is True + assert task_alive() is False @pytest.mark.asyncio @pytest.mark.parametrize( @@ -210,7 +224,7 @@ async def test_add_to_flow(self, mutations, count_cap, size_cap, expected_result """ Test batching with various flow control settings """ - mutation_objs = [_make_mutation(count=m[0], size=m[1]) for m in mutations] + mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] instance = self._make_one(count_cap, size_cap) i = 0 async for batch in instance.add_to_flow(mutation_objs): @@ -242,11 +256,16 @@ async def test_add_to_flow_max_mutation_limits( Test flow control running up against the max API limit Should submit request early, even if the flow control has room for more """ - with mock.patch( + async_patch = mock.patch( "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", max_limit, - ): - mutation_objs = [_make_mutation(count=m[0], size=m[1]) for m in mutations] + ) + sync_patch = mock.patch( + "google.cloud.bigtable.data._sync._autogen._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + max_limit, + ) + with async_patch, sync_patch: + mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] # flow control has no limits except API restrictions instance = self._make_one(float("inf"), float("inf")) i = 0 @@ -269,8 +288,8 @@ async def test_add_to_flow_oversize(self): mutations over the flow control limits should still be accepted """ instance = self._make_one(2, 3) - large_size_mutation = _make_mutation(count=1, size=10) - large_count_mutation = _make_mutation(count=10, size=1) + large_size_mutation = self._make_mutation(count=1, size=10) + large_count_mutation = self._make_mutation(count=10, size=1) results = [out async for out in instance.add_to_flow([large_size_mutation])] assert len(results) == 1 await instance.remove_from_flow(results[0]) @@ -303,6 +322,13 @@ def _make_one(self, table=None, **kwargs): return self._get_target_class()(table, **kwargs) + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + @mock.patch( "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer" ) @@ 
-580,7 +606,7 @@ async def test_append_outside_flow_limits(self): async with self._make_one( flow_control_max_mutation_count=1, flow_control_max_bytes=1 ) as instance: - oversized_entry = _make_mutation(count=0, size=2) + oversized_entry = self._make_mutation(count=0, size=2) await instance.append(oversized_entry) assert instance._staged_entries == [oversized_entry] assert instance._staged_count == 0 @@ -589,7 +615,7 @@ async def test_append_outside_flow_limits(self): async with self._make_one( flow_control_max_mutation_count=1, flow_control_max_bytes=1 ) as instance: - overcount_entry = _make_mutation(count=2, size=0) + overcount_entry = self._make_mutation(count=2, size=0) await instance.append(overcount_entry) assert instance._staged_entries == [overcount_entry] assert instance._staged_count == 2 @@ -616,11 +642,11 @@ async def mock_call(*args, **kwargs): op_mock.side_effect = mock_call # append a mutation just under the size limit - await instance.append(_make_mutation(size=99)) + await instance.append(self._make_mutation(size=99)) # append a bunch of entries back-to-back in a loop num_entries = 10 for _ in range(num_entries): - await instance.append(_make_mutation(size=1)) + await instance.append(self._make_mutation(size=1)) # let any flush jobs finish await asyncio.gather(*instance._flush_jobs) # should have only flushed once, with large mutation and first mutation in loop @@ -653,7 +679,7 @@ async def test_append( assert instance._staged_count == 0 assert instance._staged_bytes == 0 assert instance._staged_entries == [] - mutation = _make_mutation(count=mutation_count, size=mutation_bytes) + mutation = self._make_mutation(count=mutation_count, size=mutation_bytes) with mock.patch.object(instance, "_schedule_flush") as flush_mock: await instance.append(mutation) assert flush_mock.call_count == bool(expect_flush) @@ -671,7 +697,7 @@ async def test_append_multiple_sequentially(self): assert instance._staged_count == 0 assert instance._staged_bytes == 0 assert instance._staged_entries == [] - mutation = _make_mutation(count=2, size=3) + mutation = self._make_mutation(count=2, size=3) with mock.patch.object(instance, "_schedule_flush") as flush_mock: await instance.append(mutation) assert flush_mock.call_count == 0 @@ -698,7 +724,7 @@ async def test_flush_flow_control_concurrent_requests(self): import time num_calls = 10 - fake_mutations = [_make_mutation(count=1) for _ in range(num_calls)] + fake_mutations = [self._make_mutation(count=1) for _ in range(num_calls)] async with self._make_one(flow_control_max_mutation_count=1) as instance: with mock.patch.object( instance, "_execute_mutate_rows", AsyncMock() @@ -717,7 +743,7 @@ async def mock_call(*args, **kwargs): # make room for new mutations for i in range(num_calls): await instance._flow_control.remove_from_flow( - [_make_mutation(count=1)] + [self._make_mutation(count=1)] ) await asyncio.sleep(0.01) # allow flushes to complete @@ -775,7 +801,7 @@ async def gen(x): yield x flow_mock.side_effect = lambda x: gen(x) - mutations = [_make_mutation(count=1, size=1)] * num_entries + mutations = [self._make_mutation(count=1, size=1)] * num_entries await instance._flush_internal(mutations) assert instance._entries_processed_since_last_raise == num_entries assert execute_mock.call_count == 1 @@ -791,7 +817,7 @@ async def test_flush_clears_job_list(self): """ async with self._make_one() as instance: with mock.patch.object(instance, "_flush_internal", AsyncMock()): - mutations = [_make_mutation(count=1, size=1)] + mutations = 
[self._make_mutation(count=1, size=1)] instance._staged_entries = mutations assert instance._flush_jobs == set() new_job = instance._schedule_flush() @@ -836,7 +862,7 @@ async def gen(x): yield x flow_mock.side_effect = lambda x: gen(x) - mutations = [_make_mutation(count=1, size=1)] * num_entries + mutations = [self._make_mutation(count=1, size=1)] * num_entries await instance._flush_internal(mutations) assert instance._entries_processed_since_last_raise == num_entries assert execute_mock.call_count == 1 @@ -870,7 +896,7 @@ async def gen(num): async def test_timer_flush_end_to_end(self): """Flush should automatically trigger after flush_interval""" num_nutations = 10 - mutations = [_make_mutation(count=2, size=2)] * num_nutations + mutations = [self._make_mutation(count=2, size=2)] * num_nutations async with self._make_one(flush_interval=0.05) as instance: instance._table.default_operation_timeout = 10 @@ -902,7 +928,7 @@ async def test__execute_mutate_rows(self, mutate_rows): table.default_mutate_rows_attempt_timeout = 13 table.default_mutate_rows_retryable_errors = () async with self._make_one(table) as instance: - batch = [_make_mutation()] + batch = [self._make_mutation()] result = await instance._execute_mutate_rows(batch) assert start_operation.call_count == 1 args, kwargs = mutate_rows.call_args @@ -932,7 +958,7 @@ async def test__execute_mutate_rows_returns_errors(self, mutate_rows): table.default_mutate_rows_attempt_timeout = 13 table.default_mutate_rows_retryable_errors = () async with self._make_one(table) as instance: - batch = [_make_mutation()] + batch = [self._make_mutation()] result = await instance._execute_mutate_rows(batch) assert len(result) == 2 assert result[0] == err1 @@ -1058,7 +1084,7 @@ async def test_timeout_args_passed(self, mutate_rows): assert instance._operation_timeout == expected_operation_timeout assert instance._attempt_timeout == expected_attempt_timeout # make simulated gapic call - await instance._execute_mutate_rows([_make_mutation()]) + await instance._execute_mutate_rows([self._make_mutation()]) assert mutate_rows.call_count == 1 kwargs = mutate_rows.call_args[1] assert kwargs["operation_timeout"] == expected_operation_timeout @@ -1173,7 +1199,7 @@ async def test_customizable_retryable_errors( expected_predicate = lambda a: a in expected_retryables # noqa predicate_builder_mock.return_value = expected_predicate retry_fn_mock.side_effect = RuntimeError("stop early") - mutation = _make_mutation(count=1, size=1) + mutation = self._make_mutation(count=1, size=1) await instance._execute_mutate_rows([mutation]) # passed in errors should be used to build the predicate predicate_builder_mock.assert_called_once_with( diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py index e7fe2ecba..4eca799da 100644 --- a/tests/unit/data/_sync/test_autogen.py +++ b/tests/unit/data/_sync/test_autogen.py @@ -18,9 +18,12 @@ from abc import ABC from tests.unit.data._async.test__mutate_rows import TestMutateRowsOperation from tests.unit.data._async.test__read_rows import TestReadRowsOperation +from tests.unit.data._async.test_mutations_batcher import Test_FlowControl from unittest import mock import mock import pytest +import threading +import time from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -644,3 +647,241 @@ def mock_stream(): with pytest.raises(InvalidChunk) as exc: stream.__next__() assert "row keys should be strictly increasing" in str(exc.value) + + +class Test_FlowControl(ABC): 
+ @staticmethod + def _target_class(): + from google.cloud.bigtable.data._sync.mutations_batcher import _FlowControl + + return _FlowControl + + def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): + return self._target_class()(max_mutation_count, max_mutation_bytes) + + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def test_ctor(self): + max_mutation_count = 9 + max_mutation_bytes = 19 + instance = self._make_one(max_mutation_count, max_mutation_bytes) + assert instance._max_mutation_count == max_mutation_count + assert instance._max_mutation_bytes == max_mutation_bytes + assert instance._in_flight_mutation_count == 0 + assert instance._in_flight_mutation_bytes == 0 + assert isinstance(instance._capacity_condition, threading.Condition) + + def test_ctor_invalid_values(self): + """Test that values are positive, and fit within expected limits""" + with pytest.raises(ValueError) as e: + self._make_one(0, 1) + assert "max_mutation_count must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._make_one(1, 0) + assert "max_mutation_bytes must be greater than 0" in str(e.value) + + @pytest.mark.parametrize( + "max_count,max_size,existing_count,existing_size,new_count,new_size,expected", + [ + (1, 1, 0, 0, 0, 0, True), + (1, 1, 1, 1, 1, 1, False), + (10, 10, 0, 0, 0, 0, True), + (10, 10, 0, 0, 9, 9, True), + (10, 10, 0, 0, 11, 9, True), + (10, 10, 0, 1, 11, 9, True), + (10, 10, 1, 0, 11, 9, False), + (10, 10, 0, 0, 9, 11, True), + (10, 10, 1, 0, 9, 11, True), + (10, 10, 0, 1, 9, 11, False), + (10, 1, 0, 0, 1, 0, True), + (1, 10, 0, 0, 0, 8, True), + (float("inf"), float("inf"), 0, 0, 10000000000.0, 10000000000.0, True), + (8, 8, 0, 0, 10000000000.0, 10000000000.0, True), + (12, 12, 6, 6, 5, 5, True), + (12, 12, 5, 5, 6, 6, True), + (12, 12, 6, 6, 6, 6, True), + (12, 12, 6, 6, 7, 7, False), + (12, 12, 0, 0, 13, 13, True), + (12, 12, 12, 0, 0, 13, True), + (12, 12, 0, 12, 13, 0, True), + (12, 12, 1, 1, 13, 13, False), + (12, 12, 1, 1, 0, 13, False), + (12, 12, 1, 1, 13, 0, False), + ], + ) + def test__has_capacity( + self, + max_count, + max_size, + existing_count, + existing_size, + new_count, + new_size, + expected, + ): + """_has_capacity should return True if the new mutation will will not exceed the max count or size""" + instance = self._make_one(max_count, max_size) + instance._in_flight_mutation_count = existing_count + instance._in_flight_mutation_bytes = existing_size + assert instance._has_capacity(new_count, new_size) == expected + + @pytest.mark.parametrize( + "existing_count,existing_size,added_count,added_size,new_count,new_size", + [ + (0, 0, 0, 0, 0, 0), + (2, 2, 1, 1, 1, 1), + (2, 0, 1, 0, 1, 0), + (0, 2, 0, 1, 0, 1), + (10, 10, 0, 0, 10, 10), + (10, 10, 5, 5, 5, 5), + (0, 0, 1, 1, -1, -1), + ], + ) + def test_remove_from_flow_value_update( + self, + existing_count, + existing_size, + added_count, + added_size, + new_count, + new_size, + ): + """completed mutations should lower the inflight values""" + instance = self._make_one() + instance._in_flight_mutation_count = existing_count + instance._in_flight_mutation_bytes = existing_size + mutation = self._make_mutation(added_count, added_size) + instance.remove_from_flow(mutation) + assert instance._in_flight_mutation_count == new_count + assert instance._in_flight_mutation_bytes == new_size + + def test__remove_from_flow_unlock(self): + """capacity 
condition should notify after mutation is complete""" + import inspect + + instance = self._make_one(10, 10) + instance._in_flight_mutation_count = 10 + instance._in_flight_mutation_bytes = 10 + + def task_routine(): + with instance._capacity_condition: + instance._capacity_condition.wait_for( + lambda: instance._has_capacity(1, 1) + ) + + if inspect.iscoroutinefunction(task_routine): + task = threading.Thread(task_routine()) + task_alive = lambda: not task.done() + else: + import threading + + thread = threading.Thread(target=task_routine) + thread.start() + task_alive = thread.is_alive + time.sleep(0.05) + assert task_alive() is True + mutation = self._make_mutation(count=0, size=5) + instance.remove_from_flow([mutation]) + time.sleep(0.05) + assert instance._in_flight_mutation_count == 10 + assert instance._in_flight_mutation_bytes == 5 + assert task_alive() is True + instance._in_flight_mutation_bytes = 10 + mutation = self._make_mutation(count=5, size=0) + instance.remove_from_flow([mutation]) + time.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 10 + assert task_alive() is True + instance._in_flight_mutation_count = 10 + mutation = self._make_mutation(count=5, size=5) + instance.remove_from_flow([mutation]) + time.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 5 + assert task_alive() is False + + @pytest.mark.parametrize( + "mutations,count_cap,size_cap,expected_results", + [ + ([(5, 5), (1, 1), (1, 1)], 10, 10, [[(5, 5), (1, 1), (1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 1, 1, [[(1, 1)], [(1, 1)], [(1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 2, 10, [[(1, 1), (1, 1)], [(1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 10, 2, [[(1, 1), (1, 1)], [(1, 1)]]), + ( + [(1, 1), (5, 5), (4, 1), (1, 4), (1, 1)], + 5, + 5, + [[(1, 1)], [(5, 5)], [(4, 1), (1, 4)], [(1, 1)]], + ), + ], + ) + def test_add_to_flow(self, mutations, count_cap, size_cap, expected_results): + """Test batching with various flow control settings""" + mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] + instance = self._make_one(count_cap, size_cap) + i = 0 + for batch in instance.add_to_flow(mutation_objs): + expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + assert len(batch[j].mutations) == expected_batch[j][0] + assert batch[j].size() == expected_batch[j][1] + instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + @pytest.mark.parametrize( + "mutations,max_limit,expected_results", + [ + ([(1, 1)] * 11, 10, [[(1, 1)] * 10, [(1, 1)]]), + ([(1, 1)] * 10, 1, [[(1, 1)] for _ in range(10)]), + ([(1, 1)] * 10, 2, [[(1, 1), (1, 1)] for _ in range(5)]), + ], + ) + def test_add_to_flow_max_mutation_limits( + self, mutations, max_limit, expected_results + ): + """ + Test flow control running up against the max API limit + Should submit request early, even if the flow control has room for more + """ + async_patch = mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + max_limit, + ) + sync_patch = mock.patch( + "google.cloud.bigtable.data._sync._autogen._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + max_limit, + ) + with async_patch, sync_patch: + mutation_objs = [ + self._make_mutation(count=m[0], size=m[1]) for m in mutations + ] + instance = self._make_one(float("inf"), float("inf")) + i = 0 + for batch in instance.add_to_flow(mutation_objs): + 
expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + assert len(batch[j].mutations) == expected_batch[j][0] + assert batch[j].size() == expected_batch[j][1] + instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + def test_add_to_flow_oversize(self): + """mutations over the flow control limits should still be accepted""" + instance = self._make_one(2, 3) + large_size_mutation = self._make_mutation(count=1, size=10) + large_count_mutation = self._make_mutation(count=10, size=1) + results = [out for out in instance.add_to_flow([large_size_mutation])] + assert len(results) == 1 + instance.remove_from_flow(results[0]) + count_results = [out for out in instance.add_to_flow(large_count_mutation)] + assert len(count_results) == 1 From 5b9e96729fe970a1d48819adde948ba22c8c61e2 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Apr 2024 12:34:33 -0700 Subject: [PATCH 031/360] made some fixes to mutations batcher --- .../bigtable/data/_async/mutations_batcher.py | 4 +-- google/cloud/bigtable/data/_sync/_autogen.py | 26 ++++++++++++++++++- .../bigtable/data/_sync/mutations_batcher.py | 10 ++++++- .../cloud/bigtable/data/_sync/sync_gen.yaml | 3 ++- 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 8195a1929..4dfd0df1a 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -274,9 +274,7 @@ async def timer_routine(self, interval: float): # add new flush task to list if not self.closed and self._staged_entries: self._schedule_flush() - - timer_task = asyncio.create_task(timer_routine(self, interval)) - return timer_task + return self._create_bg_task(timer_routine, self, interval) async def append(self, mutation_entry: RowMutationEntry): """ diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index b616df2c8..677c95c47 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -27,6 +27,7 @@ from typing import Set from typing import cast import atexit +import concurrent.futures import functools import os import threading @@ -574,7 +575,30 @@ def __init__( def _start_flush_timer( self, interval: float | None ) -> concurrent.futures.Future[None]: - raise NotImplementedError("Function not implemented in sync class") + """ + Set up a background task to flush the batcher every interval seconds + + If interval is None, an empty future is returned + + Args: + - flush_interval: Automatically flush every flush_interval seconds. + If None, no time-based flushing is performed. 
+ Returns: + - asyncio.Future that represents the background task + """ + if interval is None or self.closed: + empty_future: concurrent.futures.Future[None] = concurrent.futures.Future() + empty_future.set_result(None) + return empty_future + + def timer_routine(self, interval: float): + """Triggers new flush tasks every `interval` seconds""" + while not self.closed: + time.sleep(interval) + if not self.closed and self._staged_entries: + self._schedule_flush() + + return self._create_bg_task(timer_routine, self, interval) def append(self, mutation_entry: RowMutationEntry): """ diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index d7cbc428c..13f17a642 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -14,6 +14,7 @@ # from __future__ import annotations +from concurrent.futures import ThreadPoolExecutor from google.cloud.bigtable.data._sync._autogen import _FlowControl_SyncGen from google.cloud.bigtable.data._sync._autogen import MutationsBatcher_SyncGen @@ -22,5 +23,12 @@ class _FlowControl(_FlowControl_SyncGen): pass + class MutationsBatcher(MutationsBatcher_SyncGen): - pass + + def __init__(self, *args, **kwargs): + self._executor = ThreadPoolExecutor(max_workers=8) + super().__init__(*args, **kwargs) + + def _create_bg_task(self, func, *args, **kwargs): + return self._executor.submit(func, *args, **kwargs) diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index 1909a3dcf..a9ebe5db4 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -28,6 +28,7 @@ added_imports: - "from grpc import Channel" - "import google.cloud.bigtable.data.exceptions as bt_exceptions" - "import threading" + - "import concurrent.futures" classes: # Specify transformations for individual classes - path: google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync @@ -40,7 +41,7 @@ classes: # Specify transformations for individual classes autogen_sync_name: MutationsBatcher_SyncGen concrete_path: google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher pass_methods: ["close", "_wait_for_batch_results"] - error_methods: ["_create_bg_task", "_start_flush_timer"] + error_methods: ["_create_bg_task"] - path: google.cloud.bigtable.data._async.mutations_batcher._FlowControlAsync autogen_sync_name: _FlowControl_SyncGen concrete_path: google.cloud.bigtable.data._sync.mutations_batcher._FlowControl From 717c4012880014c8703092ced2622970d667bb5c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Apr 2024 13:13:01 -0700 Subject: [PATCH 032/360] improved modularity of mutation batcher tests --- .../bigtable/data/_sync/mutations_batcher.py | 2 + .../data/_async/test_mutations_batcher.py | 488 +++-- tests/unit/data/_sync/test_autogen.py | 1596 ++++++++--------- 3 files changed, 1023 insertions(+), 1063 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 13f17a642..8a1332f3e 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -19,6 +19,8 @@ from google.cloud.bigtable.data._sync._autogen import _FlowControl_SyncGen from google.cloud.bigtable.data._sync._autogen import MutationsBatcher_SyncGen +# import required so MutationsBatcher_SyncGen can create _MutateRowsOperation 
+import google.cloud.bigtable.data._sync._mutate_rows # noqa: F401 class _FlowControl(_FlowControl_SyncGen): pass diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index b790bcfb8..f18f5884b 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -307,6 +307,12 @@ def _get_target_class(self): return MutationsBatcherAsync + def _get_mutate_rows_class_path(self): + # location the MutateRowsOperation class is imported for mocking + # will be different in sync vs async versions + + return "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync" + def _make_one(self, table=None, **kwargs): from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import ServiceUnavailable @@ -329,130 +335,121 @@ def _make_mutation(count=1, size=1): mutation.mutations = [mock.Mock()] * count return mutation - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer" - ) @pytest.mark.asyncio - async def test_ctor_defaults(self, flush_timer_mock): - flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = [Exception] - async with self._make_one(table) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._max_mutation_count == 100000 - assert instance._flow_control._max_mutation_bytes == 104857600 - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - assert ( - instance._operation_timeout - == table.default_mutate_rows_operation_timeout - ) - assert ( - instance._attempt_timeout == table.default_mutate_rows_attempt_timeout - ) - assert ( - instance._retryable_errors == table.default_mutate_rows_retryable_errors - ) - await asyncio.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] == 5 - assert isinstance(instance._flush_timer, asyncio.Future) + async def test_ctor_defaults(self): + with mock.patch.object(self._get_target_class(), "_start_flush_timer", return_value=asyncio.Future()) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = [Exception] + async with self._make_one(table) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._max_mutation_count == 100000 + assert instance._flow_control._max_mutation_bytes == 104857600 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 
0 + assert instance._entries_processed_since_last_raise == 0 + assert ( + instance._operation_timeout + == table.default_mutate_rows_operation_timeout + ) + assert ( + instance._attempt_timeout == table.default_mutate_rows_attempt_timeout + ) + assert ( + instance._retryable_errors == table.default_mutate_rows_retryable_errors + ) + await asyncio.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == 5 + assert isinstance(instance._flush_timer, asyncio.Future) - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer", - ) @pytest.mark.asyncio - async def test_ctor_explicit(self, flush_timer_mock): + async def test_ctor_explicit(self): """Test with explicit parameters""" - flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) - table = mock.Mock() - flush_interval = 20 - flush_limit_count = 17 - flush_limit_bytes = 19 - flow_control_max_mutation_count = 1001 - flow_control_max_bytes = 12 - operation_timeout = 11 - attempt_timeout = 2 - retryable_errors = [Exception] - async with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - flow_control_max_mutation_count=flow_control_max_mutation_count, - flow_control_max_bytes=flow_control_max_bytes, - batch_operation_timeout=operation_timeout, - batch_attempt_timeout=attempt_timeout, - batch_retryable_errors=retryable_errors, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert ( - instance._flow_control._max_mutation_count - == flow_control_max_mutation_count - ) - assert instance._flow_control._max_mutation_bytes == flow_control_max_bytes - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - assert instance._operation_timeout == operation_timeout - assert instance._attempt_timeout == attempt_timeout - assert instance._retryable_errors == retryable_errors - await asyncio.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] == flush_interval - assert isinstance(instance._flush_timer, asyncio.Future) - - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer" - ) + with mock.patch.object(self._get_target_class(), "_start_flush_timer", return_value=asyncio.Future()) as flush_timer_mock: + table = mock.Mock() + flush_interval = 20 + flush_limit_count = 17 + flush_limit_bytes = 19 + flow_control_max_mutation_count = 1001 + flow_control_max_bytes = 12 + operation_timeout = 11 + attempt_timeout = 2 + retryable_errors = [Exception] + async with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=operation_timeout, + batch_attempt_timeout=attempt_timeout, + batch_retryable_errors=retryable_errors, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert 
instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert ( + instance._flow_control._max_mutation_count + == flow_control_max_mutation_count + ) + assert instance._flow_control._max_mutation_bytes == flow_control_max_bytes + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert instance._operation_timeout == operation_timeout + assert instance._attempt_timeout == attempt_timeout + assert instance._retryable_errors == retryable_errors + await asyncio.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == flush_interval + assert isinstance(instance._flush_timer, asyncio.Future) + @pytest.mark.asyncio - async def test_ctor_no_flush_limits(self, flush_timer_mock): + async def test_ctor_no_flush_limits(self): """Test with None for flush limits""" - flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = () - flush_interval = None - flush_limit_count = None - flush_limit_bytes = None - async with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._staged_entries == [] - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - await asyncio.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] is None - assert isinstance(instance._flush_timer, asyncio.Future) + with mock.patch.object(self._get_target_class(), "_start_flush_timer", return_value=asyncio.Future()) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = () + flush_interval = None + flush_limit_count = None + flush_limit_bytes = None + async with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._staged_entries == [] + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + await asyncio.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] is None + assert isinstance(instance._flush_timer, asyncio.Future) @pytest.mark.asyncio async def 
test_ctor_invalid_values(self): @@ -496,87 +493,79 @@ def test_default_argument_consistency(self): == batcher_init_signature[arg_name].default ) - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) @pytest.mark.asyncio - async def test__start_flush_timer_w_None(self, flush_mock): + async def test__start_flush_timer_w_None(self): """Empty timer should return immediately""" - async with self._make_one() as instance: - with mock.patch("asyncio.sleep") as sleep_mock: - await instance._start_flush_timer(None) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 + with mock.patch.object(self._get_target_class(), "_schedule_flush") as flush_mock: + async with self._make_one() as instance: + with mock.patch("asyncio.sleep") as sleep_mock: + await instance._start_flush_timer(None) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) @pytest.mark.asyncio - async def test__start_flush_timer_call_when_closed(self, flush_mock): + async def test__start_flush_timer_call_when_closed(self,): """closed batcher's timer should return immediately""" - async with self._make_one() as instance: - await instance.close() - flush_mock.reset_mock() - with mock.patch("asyncio.sleep") as sleep_mock: - await instance._start_flush_timer(1) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 + with mock.patch.object(self._get_target_class(), "_schedule_flush") as flush_mock: + async with self._make_one() as instance: + await instance.close() + flush_mock.reset_mock() + with mock.patch("asyncio.sleep") as sleep_mock: + await instance._start_flush_timer(1) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) @pytest.mark.asyncio - async def test__flush_timer(self, flush_mock): + async def test__flush_timer(self): """Timer should continue to call _schedule_flush in a loop""" - expected_sleep = 12 - async with self._make_one(flush_interval=expected_sleep) as instance: - instance._staged_entries = [mock.Mock()] - loop_num = 3 - with mock.patch("asyncio.sleep") as sleep_mock: - sleep_mock.side_effect = [None] * loop_num + [asyncio.CancelledError()] - try: - await instance._flush_timer - except asyncio.CancelledError: - pass - assert sleep_mock.call_count == loop_num + 1 - sleep_mock.assert_called_with(expected_sleep) - assert flush_mock.call_count == loop_num - - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) + with mock.patch.object(self._get_target_class(), "_schedule_flush") as flush_mock: + expected_sleep = 12 + async with self._make_one(flush_interval=expected_sleep) as instance: + instance._staged_entries = [mock.Mock()] + loop_num = 3 + with mock.patch("asyncio.sleep") as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [ZeroDivisionError("expected")] + try: + await instance._flush_timer + except ZeroDivisionError: + # replace with np-op so there are no issues on close + instance._flush_timer = asyncio.Future() + assert sleep_mock.call_count == loop_num + 1 + sleep_mock.assert_called_with(expected_sleep) + assert flush_mock.call_count == loop_num + @pytest.mark.asyncio - async def test__flush_timer_no_mutations(self, flush_mock): + async def test__flush_timer_no_mutations(self): """Timer 
should not flush if no new mutations have been staged""" - expected_sleep = 12 - async with self._make_one(flush_interval=expected_sleep) as instance: - loop_num = 3 - with mock.patch("asyncio.sleep") as sleep_mock: - sleep_mock.side_effect = [None] * loop_num + [asyncio.CancelledError()] - try: - await instance._flush_timer - except asyncio.CancelledError: - pass - assert sleep_mock.call_count == loop_num + 1 - sleep_mock.assert_called_with(expected_sleep) - assert flush_mock.call_count == 0 + with mock.patch.object(self._get_target_class(), "_schedule_flush") as flush_mock: + expected_sleep = 12 + async with self._make_one(flush_interval=expected_sleep) as instance: + loop_num = 3 + with mock.patch("asyncio.sleep") as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] + try: + await instance._flush_timer + except TabError: + # replace with np-op so there are no issues on close + instance._flush_timer = asyncio.Future() + assert sleep_mock.call_count == loop_num + 1 + sleep_mock.assert_called_with(expected_sleep) + assert flush_mock.call_count == 0 - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) @pytest.mark.asyncio - async def test__flush_timer_close(self, flush_mock): + async def test__flush_timer_close(self): """Timer should continue terminate after close""" - async with self._make_one() as instance: - with mock.patch("asyncio.sleep"): - # let task run in background - await asyncio.sleep(0.5) - assert instance._flush_timer.done() is False - # close the batcher - await instance.close() - await asyncio.sleep(0.1) - # task should be complete - assert instance._flush_timer.done() is True + with mock.patch.object(self._get_target_class(), "_schedule_flush") as flush_mock: + async with self._make_one() as instance: + with mock.patch("asyncio.sleep"): + # let task run in background + await asyncio.sleep(0.5) + assert instance._flush_timer.done() is False + # close the batcher + await instance.close() + await asyncio.sleep(0.1) + # task should be complete + assert instance._flush_timer.done() is True @pytest.mark.asyncio async def test_append_closed(self): @@ -628,12 +617,8 @@ async def test_append_flush_runs_after_limit_hit(self): If the user appends a bunch of entries above the flush limits back-to-back, it should still flush in a single task """ - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) - with mock.patch.object( - MutationsBatcherAsync, "_execute_mutate_rows" + self._get_target_class(), "_execute_mutate_rows" ) as op_mock: async with self._make_one(flush_limit_bytes=100) as instance: # mock network calls @@ -915,57 +900,54 @@ async def test_timer_flush_end_to_end(self): assert instance._entries_processed_since_last_raise == num_nutations @pytest.mark.asyncio - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync", - ) - async def test__execute_mutate_rows(self, mutate_rows): - mutate_rows.return_value = AsyncMock() - start_operation = mutate_rows().start - table = mock.Mock() - table.table_name = "test-table" - table.app_profile_id = "test-app-profile" - table.default_mutate_rows_operation_timeout = 17 - table.default_mutate_rows_attempt_timeout = 13 - table.default_mutate_rows_retryable_errors = () - async with self._make_one(table) as instance: - batch = [self._make_mutation()] - result = await instance._execute_mutate_rows(batch) - assert start_operation.call_count == 1 - args, kwargs = 
mutate_rows.call_args - assert args[0] == table.client._gapic_client - assert args[1] == table - assert args[2] == batch - kwargs["operation_timeout"] == 17 - kwargs["attempt_timeout"] == 13 - assert result == [] + async def test__execute_mutate_rows(self): + mutate_path = self._get_mutate_rows_class_path() + with mock.patch(mutate_path) as mutate_rows: + mutate_rows.return_value = AsyncMock() + start_operation = mutate_rows().start + table = mock.Mock() + table.table_name = "test-table" + table.app_profile_id = "test-app-profile" + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + async with self._make_one(table) as instance: + batch = [self._make_mutation()] + result = await instance._execute_mutate_rows(batch) + assert start_operation.call_count == 1 + args, kwargs = mutate_rows.call_args + assert args[0] == table.client._gapic_client + assert args[1] == table + assert args[2] == batch + kwargs["operation_timeout"] == 17 + kwargs["attempt_timeout"] == 13 + assert result == [] @pytest.mark.asyncio - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync.start" - ) - async def test__execute_mutate_rows_returns_errors(self, mutate_rows): + async def test__execute_mutate_rows_returns_errors(self): """Errors from operation should be retruned as list""" from google.cloud.bigtable.data.exceptions import ( MutationsExceptionGroup, FailedMutationEntryError, ) - - err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) - err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) - mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 17 - table.default_mutate_rows_attempt_timeout = 13 - table.default_mutate_rows_retryable_errors = () - async with self._make_one(table) as instance: - batch = [self._make_mutation()] - result = await instance._execute_mutate_rows(batch) - assert len(result) == 2 - assert result[0] == err1 - assert result[1] == err2 - # indices should be set to None - assert result[0].index is None - assert result[1].index is None + cls_path = self._get_mutate_rows_class_path() + with mock.patch(f"{cls_path}.start") as mutate_rows: + err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) + err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) + mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + async with self._make_one(table) as instance: + batch = [self._make_mutation()] + result = await instance._execute_mutate_rows(batch) + assert len(result) == 2 + assert result[0] == err1 + assert result[1] == err2 + # indices should be set to None + assert result[0].index is None + assert result[1].index is None @pytest.mark.asyncio async def test__raise_exceptions(self): @@ -1066,29 +1048,27 @@ async def test_atexit_registration(self): assert register_mock.call_count == 1 @pytest.mark.asyncio - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync", - ) - async def test_timeout_args_passed(self, mutate_rows): + async def test_timeout_args_passed(self): """ batch_operation_timeout and batch_attempt_timeout should be used in api calls """ - mutate_rows.return_value = 
AsyncMock() - expected_operation_timeout = 17 - expected_attempt_timeout = 13 - async with self._make_one( - batch_operation_timeout=expected_operation_timeout, - batch_attempt_timeout=expected_attempt_timeout, - ) as instance: - assert instance._operation_timeout == expected_operation_timeout - assert instance._attempt_timeout == expected_attempt_timeout - # make simulated gapic call - await instance._execute_mutate_rows([self._make_mutation()]) - assert mutate_rows.call_count == 1 - kwargs = mutate_rows.call_args[1] - assert kwargs["operation_timeout"] == expected_operation_timeout - assert kwargs["attempt_timeout"] == expected_attempt_timeout + mutate_path = self._get_mutate_rows_class_path() + with mock.patch(mutate_path, return_value=AsyncMock()) as mutate_rows: + expected_operation_timeout = 17 + expected_attempt_timeout = 13 + async with self._make_one( + batch_operation_timeout=expected_operation_timeout, + batch_attempt_timeout=expected_attempt_timeout, + ) as instance: + assert instance._operation_timeout == expected_operation_timeout + assert instance._attempt_timeout == expected_attempt_timeout + # make simulated gapic call + await instance._execute_mutate_rows([self._make_mutation()]) + assert mutate_rows.call_count == 1 + kwargs = mutate_rows.call_args[1] + assert kwargs["operation_timeout"] == expected_operation_timeout + assert kwargs["attempt_timeout"] == expected_attempt_timeout @pytest.mark.parametrize( "limit,in_e,start_e,end_e", diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py index 4eca799da..e9692a7e7 100644 --- a/tests/unit/data/_sync/test_autogen.py +++ b/tests/unit/data/_sync/test_autogen.py @@ -16,872 +16,850 @@ from abc import ABC -from tests.unit.data._async.test__mutate_rows import TestMutateRowsOperation -from tests.unit.data._async.test__read_rows import TestReadRowsOperation -from tests.unit.data._async.test_mutations_batcher import Test_FlowControl from unittest import mock +import concurrent.futures import mock import pytest -import threading import time -from google.cloud.bigtable_v2.types import MutateRowsResponse -from google.rpc import status_pb2 +from google.cloud.bigtable.data import TABLE_DEFAULT +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete import google.api_core.exceptions as core_exceptions -class TestMutateRowsOperation(ABC): - def _target_class(self): - from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation +class TestMutationsBatcher(ABC): + def _get_target_class(self): + from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher - return _MutateRowsOperation + return MutationsBatcher - def _make_one(self, *args, **kwargs): - if not args: - kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) - kwargs["table"] = kwargs.pop("table", mock.Mock()) - kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) - kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) - kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) - kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) - return self._target_class()(*args, **kwargs) + def _get_mutate_rows_class_path(self): + return "google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation" - def _make_mutation(self, count=1, size=1): + def _make_one(self, table=None, **kwargs): + from google.api_core.exceptions import DeadlineExceeded + from google.api_core.exceptions import ServiceUnavailable + + if table is None: + table = 
mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 10 + table.default_mutate_rows_retryable_errors = ( + DeadlineExceeded, + ServiceUnavailable, + ) + return self._get_target_class()(table, **kwargs) + + @staticmethod + def _make_mutation(count=1, size=1): mutation = mock.Mock() mutation.size.return_value = size mutation.mutations = [mock.Mock()] * count return mutation - def _mock_stream(self, mutation_list, error_dict): - for idx, entry in enumerate(mutation_list): - code = error_dict.get(idx, 0) - yield MutateRowsResponse( - entries=[ - MutateRowsResponse.Entry( - index=idx, status=status_pb2.Status(code=code) - ) - ] - ) - - def _make_mock_gapic(self, mutation_list, error_dict=None): - mock_fn = mock.Mock() - if error_dict is None: - error_dict = {} - mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( - mutation_list, error_dict - ) - return mock_fn - - def test_ctor(self): - """test that constructor sets all the attributes correctly""" - from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto - from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - from google.api_core.exceptions import DeadlineExceeded - from google.api_core.exceptions import Aborted - - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] - operation_timeout = 0.05 - attempt_timeout = 0.01 - retryable_exceptions = () - instance = self._make_one( - client, - table, - entries, - operation_timeout, - attempt_timeout, - retryable_exceptions, - ) - assert client.mutate_rows.call_count == 0 - instance._gapic_fn() - assert client.mutate_rows.call_count == 1 - inner_kwargs = client.mutate_rows.call_args[1] - assert len(inner_kwargs) == 4 - assert inner_kwargs["table_name"] == table.table_name - assert inner_kwargs["app_profile_id"] == table.app_profile_id - assert inner_kwargs["retry"] is None - metadata = inner_kwargs["metadata"] - assert len(metadata) == 1 - assert metadata[0][0] == "x-goog-request-params" - assert str(table.table_name) in metadata[0][1] - assert str(table.app_profile_id) in metadata[0][1] - entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries] - assert instance.mutations == entries_w_pb - assert next(instance.timeout_generator) == attempt_timeout - assert instance.is_retryable is not None - assert instance.is_retryable(DeadlineExceeded("")) is False - assert instance.is_retryable(Aborted("")) is False - assert instance.is_retryable(_MutateRowsIncomplete("")) is True - assert instance.is_retryable(RuntimeError("")) is False - assert instance.remaining_indices == list(range(len(entries))) - assert instance.errors == {} - - def test_ctor_too_many_entries(self): - """should raise an error if an operation is created with more than 100,000 entries""" - from google.cloud.bigtable.data._async._mutate_rows import ( - _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, - ) - - assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100000 - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT - operation_timeout = 0.05 - attempt_timeout = 0.01 - self._make_one(client, table, entries, operation_timeout, attempt_timeout) - with pytest.raises(ValueError) as e: - self._make_one( - client, - table, - entries + [self._make_mutation()], - operation_timeout, - attempt_timeout, - ) - assert "mutate_rows requests can contain at most 100000 mutations" in str( - e.value - ) - assert "Found 100001" in str(e.value) - - def 
test_mutate_rows_operation(self): - """Test successful case of mutate_rows_operation""" - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] - operation_timeout = 0.05 - cls = self._target_class() - with mock.patch( - f"{cls.__module__}.{cls.__name__}._run_attempt", mock.Mock() - ) as attempt_mock: - instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout - ) - instance.start() - assert attempt_mock.call_count == 1 - - @pytest.mark.parametrize( - "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] - ) - def test_mutate_rows_attempt_exception(self, exc_type): - """exceptions raised from attempt should be raised in MutationsExceptionGroup""" - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] - operation_timeout = 0.05 - expected_exception = exc_type("test") - client.mutate_rows.side_effect = expected_exception - found_exc = None - try: - instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout - ) - instance._run_attempt() - except Exception as e: - found_exc = e - assert client.mutate_rows.call_count == 1 - assert type(found_exc) is exc_type - assert found_exc == expected_exception - assert len(instance.errors) == 2 - assert len(instance.remaining_indices) == 0 - - @pytest.mark.parametrize( - "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] - ) - def test_mutate_rows_exception(self, exc_type): - """exceptions raised from retryable should be raised in MutationsExceptionGroup""" - from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup - from google.cloud.bigtable.data.exceptions import FailedMutationEntryError - - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] - operation_timeout = 0.05 - expected_cause = exc_type("abort") + def test_ctor_defaults(self): with mock.patch.object( - self._target_class(), "_run_attempt", mock.Mock() - ) as attempt_mock: - attempt_mock.side_effect = expected_cause - found_exc = None - try: - instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout + self._get_target_class(), + "_start_flush_timer", + return_value=concurrent.futures.Future(), + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = [Exception] + with self._make_one(table) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._max_mutation_count == 100000 + assert instance._flow_control._max_mutation_bytes == 104857600 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert ( + instance._operation_timeout + == table.default_mutate_rows_operation_timeout + ) + assert ( + instance._attempt_timeout + == table.default_mutate_rows_attempt_timeout + ) + assert ( + instance._retryable_errors + == table.default_mutate_rows_retryable_errors ) - instance.start() - except MutationsExceptionGroup 
as e: - found_exc = e - assert attempt_mock.call_count == 1 - assert len(found_exc.exceptions) == 2 - assert isinstance(found_exc.exceptions[0], FailedMutationEntryError) - assert isinstance(found_exc.exceptions[1], FailedMutationEntryError) - assert found_exc.exceptions[0].__cause__ == expected_cause - assert found_exc.exceptions[1].__cause__ == expected_cause + time.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == 5 + assert isinstance(instance._flush_timer, concurrent.futures.Future) - @pytest.mark.parametrize( - "exc_type", [core_exceptions.DeadlineExceeded, RuntimeError] - ) - def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): - """If an exception fails but eventually passes, it should not raise an exception""" - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation()] - operation_timeout = 1 - expected_cause = exc_type("retry") - num_retries = 2 + def test_ctor_explicit(self): + """Test with explicit parameters""" with mock.patch.object( - self._target_class(), "_run_attempt", mock.Mock() - ) as attempt_mock: - attempt_mock.side_effect = [expected_cause] * num_retries + [None] - instance = self._make_one( - client, + self._get_target_class(), + "_start_flush_timer", + return_value=concurrent.futures.Future(), + ) as flush_timer_mock: + table = mock.Mock() + flush_interval = 20 + flush_limit_count = 17 + flush_limit_bytes = 19 + flow_control_max_mutation_count = 1001 + flow_control_max_bytes = 12 + operation_timeout = 11 + attempt_timeout = 2 + retryable_errors = [Exception] + with self._make_one( table, - entries, - operation_timeout, - operation_timeout, - retryable_exceptions=(exc_type,), - ) - instance.start() - assert attempt_mock.call_count == num_retries + 1 - - def test_mutate_rows_incomplete_ignored(self): - """MutateRowsIncomplete exceptions should not be added to error list""" - from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup - from google.api_core.exceptions import DeadlineExceeded - - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation()] - operation_timeout = 0.05 - with mock.patch.object( - self._target_class(), "_run_attempt", mock.Mock() - ) as attempt_mock: - attempt_mock.side_effect = _MutateRowsIncomplete("ignored") - found_exc = None - try: - instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=operation_timeout, + batch_attempt_timeout=attempt_timeout, + batch_retryable_errors=retryable_errors, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert ( + instance._flow_control._max_mutation_count + == flow_control_max_mutation_count ) - instance.start() - except MutationsExceptionGroup as e: - found_exc = e - assert attempt_mock.call_count > 0 - assert len(found_exc.exceptions) == 1 - assert isinstance(found_exc.exceptions[0].__cause__, 
DeadlineExceeded) - - def test_run_attempt_single_entry_success(self): - """Test mutating a single entry""" - mutation = self._make_mutation() - expected_timeout = 1.3 - mock_gapic_fn = self._make_mock_gapic({0: mutation}) - instance = self._make_one( - mutation_entries=[mutation], attempt_timeout=expected_timeout - ) - with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): - instance._run_attempt() - assert len(instance.remaining_indices) == 0 - assert mock_gapic_fn.call_count == 1 - (_, kwargs) = mock_gapic_fn.call_args - assert kwargs["timeout"] == expected_timeout - assert kwargs["entries"] == [mutation._to_pb()] - - def test_run_attempt_empty_request(self): - """Calling with no mutations should result in no API calls""" - mock_gapic_fn = self._make_mock_gapic([]) - instance = self._make_one(mutation_entries=[]) - instance._run_attempt() - assert mock_gapic_fn.call_count == 0 - - def test_run_attempt_partial_success_retryable(self): - """Some entries succeed, but one fails. Should report the proper index, and raise incomplete exception""" - from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - - success_mutation = self._make_mutation() - success_mutation_2 = self._make_mutation() - failure_mutation = self._make_mutation() - mutations = [success_mutation, failure_mutation, success_mutation_2] - mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) - instance = self._make_one(mutation_entries=mutations) - instance.is_retryable = lambda x: True - with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): - with pytest.raises(_MutateRowsIncomplete): - instance._run_attempt() - assert instance.remaining_indices == [1] - assert 0 not in instance.errors - assert len(instance.errors[1]) == 1 - assert instance.errors[1][0].grpc_status_code == 300 - assert 2 not in instance.errors - - def test_run_attempt_partial_success_non_retryable(self): - """Some entries succeed, but one fails. Exception marked as non-retryable. 
Do not raise incomplete error""" - success_mutation = self._make_mutation() - success_mutation_2 = self._make_mutation() - failure_mutation = self._make_mutation() - mutations = [success_mutation, failure_mutation, success_mutation_2] - mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) - instance = self._make_one(mutation_entries=mutations) - instance.is_retryable = lambda x: False - with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): - instance._run_attempt() - assert instance.remaining_indices == [] - assert 0 not in instance.errors - assert len(instance.errors[1]) == 1 - assert instance.errors[1][0].grpc_status_code == 300 - assert 2 not in instance.errors - - -class TestReadRowsOperation(ABC): - """ - Tests helper functions in the ReadRowsOperation class - in-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt - is tested in test_read_rows_acceptance test_client_read_rows, and conformance tests - """ - - @staticmethod - def _get_target_class(): - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - - return _ReadRowsOperation - - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) - - def test_ctor(self): - from google.cloud.bigtable.data import ReadRowsQuery - - row_limit = 91 - query = ReadRowsQuery(limit=row_limit) - client = mock.Mock() - client.read_rows = mock.Mock() - client.read_rows.return_value = None - table = mock.Mock() - table._client = client - table.table_name = "test_table" - table.app_profile_id = "test_profile" - expected_operation_timeout = 42 - expected_request_timeout = 44 - time_gen_mock = mock.Mock() - with mock.patch( - "google.cloud.bigtable.data._helpers._attempt_timeout_generator", - time_gen_mock, - ): - instance = self._make_one( - query, + assert ( + instance._flow_control._max_mutation_bytes == flow_control_max_bytes + ) + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert instance._operation_timeout == operation_timeout + assert instance._attempt_timeout == attempt_timeout + assert instance._retryable_errors == retryable_errors + time.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == flush_interval + assert isinstance(instance._flush_timer, concurrent.futures.Future) + + def test_ctor_no_flush_limits(self): + """Test with None for flush limits""" + with mock.patch.object( + self._get_target_class(), + "_start_flush_timer", + return_value=concurrent.futures.Future(), + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = () + flush_interval = None + flush_limit_count = None + flush_limit_bytes = None + with self._make_one( table, - operation_timeout=expected_operation_timeout, - attempt_timeout=expected_request_timeout, - ) - assert time_gen_mock.call_count == 1 - time_gen_mock.assert_called_once_with( - expected_request_timeout, expected_operation_timeout - ) - assert instance._last_yielded_row_key is None - assert instance._remaining_count == row_limit - assert instance.operation_timeout == expected_operation_timeout - assert client.read_rows.call_count == 0 - assert instance._metadata == [ - ( - "x-goog-request-params", - "table_name=test_table&app_profile_id=test_profile", - ) - ] - assert 
instance.request.table_name == table.table_name - assert instance.request.app_profile_id == table.app_profile_id - assert instance.request.rows_limit == row_limit - - @pytest.mark.parametrize( - "in_keys,last_key,expected", - [ - (["b", "c", "d"], "a", ["b", "c", "d"]), - (["a", "b", "c"], "b", ["c"]), - (["a", "b", "c"], "c", []), - (["a", "b", "c"], "d", []), - (["d", "c", "b", "a"], "b", ["d", "c"]), - ], - ) - def test_revise_request_rowset_keys(self, in_keys, last_key, expected): - from google.cloud.bigtable_v2.types import RowSet as RowSetPB - from google.cloud.bigtable_v2.types import RowRange as RowRangePB - - in_keys = [key.encode("utf-8") for key in in_keys] - expected = [key.encode("utf-8") for key in expected] - last_key = last_key.encode("utf-8") - sample_range = RowRangePB(start_key_open=last_key) - row_set = RowSetPB(row_keys=in_keys, row_ranges=[sample_range]) - revised = self._get_target_class()._revise_request_rowset(row_set, last_key) - assert revised.row_keys == expected - assert revised.row_ranges == [sample_range] + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._staged_entries == [] + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + time.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] is None + assert isinstance(instance._flush_timer, concurrent.futures.Future) - @pytest.mark.parametrize( - "in_ranges,last_key,expected", - [ - ( - [{"start_key_open": "b", "end_key_closed": "d"}], - "a", - [{"start_key_open": "b", "end_key_closed": "d"}], - ), - ( - [{"start_key_closed": "b", "end_key_closed": "d"}], - "a", - [{"start_key_closed": "b", "end_key_closed": "d"}], - ), - ( - [{"start_key_open": "a", "end_key_closed": "d"}], - "b", - [{"start_key_open": "b", "end_key_closed": "d"}], - ), - ( - [{"start_key_closed": "a", "end_key_open": "d"}], - "b", - [{"start_key_open": "b", "end_key_open": "d"}], - ), - ( - [{"start_key_closed": "b", "end_key_closed": "d"}], - "b", - [{"start_key_open": "b", "end_key_closed": "d"}], - ), - ([{"start_key_closed": "b", "end_key_closed": "d"}], "d", []), - ([{"start_key_closed": "b", "end_key_open": "d"}], "d", []), - ([{"start_key_closed": "b", "end_key_closed": "d"}], "e", []), - ([{"start_key_closed": "b"}], "z", [{"start_key_open": "z"}]), - ([{"start_key_closed": "b"}], "a", [{"start_key_closed": "b"}]), - ( - [{"end_key_closed": "z"}], - "a", - [{"start_key_open": "a", "end_key_closed": "z"}], - ), - ( - [{"end_key_open": "z"}], - "a", - [{"start_key_open": "a", "end_key_open": "z"}], - ), - ], - ) - def test_revise_request_rowset_ranges(self, in_ranges, last_key, expected): - from google.cloud.bigtable_v2.types import RowSet as RowSetPB - from google.cloud.bigtable_v2.types import RowRange as RowRangePB - - next_key = (last_key + "a").encode("utf-8") - last_key = last_key.encode("utf-8") - in_ranges = [ - RowRangePB(**{k: v.encode("utf-8") for (k, v) in r.items()}) - for r in in_ranges - ] - expected = [ - RowRangePB(**{k: v.encode("utf-8") for (k, v) in r.items()}) - for r in expected - ] - 
row_set = RowSetPB(row_ranges=in_ranges, row_keys=[next_key]) - revised = self._get_target_class()._revise_request_rowset(row_set, last_key) - assert revised.row_keys == [next_key] - assert revised.row_ranges == expected - - @pytest.mark.parametrize("last_key", ["a", "b", "c"]) - def test_revise_request_full_table(self, last_key): - from google.cloud.bigtable_v2.types import RowSet as RowSetPB - from google.cloud.bigtable_v2.types import RowRange as RowRangePB - - last_key = last_key.encode("utf-8") - row_set = RowSetPB() - for selected_set in [row_set, None]: - revised = self._get_target_class()._revise_request_rowset( - selected_set, last_key - ) - assert revised.row_keys == [] - assert len(revised.row_ranges) == 1 - assert revised.row_ranges[0] == RowRangePB(start_key_open=last_key) - - def test_revise_to_empty_rowset(self): - """revising to an empty rowset should raise error""" - from google.cloud.bigtable.data.exceptions import _RowSetComplete - from google.cloud.bigtable_v2.types import RowSet as RowSetPB - from google.cloud.bigtable_v2.types import RowRange as RowRangePB - - row_keys = [b"a", b"b", b"c"] - row_range = RowRangePB(end_key_open=b"c") - row_set = RowSetPB(row_keys=row_keys, row_ranges=[row_range]) - with pytest.raises(_RowSetComplete): - self._get_target_class()._revise_request_rowset(row_set, b"d") + def test_ctor_invalid_values(self): + """Test that timeout values are positive, and fit within expected limits""" + with pytest.raises(ValueError) as e: + self._make_one(batch_operation_timeout=-1) + assert "operation_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._make_one(batch_attempt_timeout=-1) + assert "attempt_timeout must be greater than 0" in str(e.value) - @pytest.mark.parametrize( - "start_limit,emit_num,expected_limit", - [ - (10, 0, 10), - (10, 1, 9), - (10, 10, 0), - (None, 10, None), - (None, 0, None), - (4, 2, 2), - ], - ) - def test_revise_limit(self, start_limit, emit_num, expected_limit): + def test_default_argument_consistency(self): """ - revise_limit should revise the request's limit field - - if limit is 0 (unlimited), it should never be revised - - if start_limit-emit_num == 0, the request should end early - - if the number emitted exceeds the new limit, an exception should - should be raised (tested in test_revise_limit_over_limit) + We supply default arguments in MutationsBatcherAsync.__init__, and in + table.mutations_batcher. 
Make sure any changes to defaults are applied to + both places """ - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable_v2.types import ReadRowsResponse - - def awaitable_stream(): - def mock_stream(): - for i in range(emit_num): - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk( - row_key=str(i).encode(), - family_name="b", - qualifier=b"c", - value=b"d", - commit_row=True, - ) - ] - ) - - return mock_stream() + from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) + import inspect - query = ReadRowsQuery(limit=start_limit) - table = mock.Mock() - table.table_name = "table_name" - table.app_profile_id = "app_profile_id" - instance = self._make_one(query, table, 10, 10) - assert instance._remaining_count == start_limit - for val in instance.chunk_stream(awaitable_stream()): - pass - assert instance._remaining_count == expected_limit + get_batcher_signature = dict( + inspect.signature(TableAsync.mutations_batcher).parameters + ) + get_batcher_signature.pop("self") + batcher_init_signature = dict( + inspect.signature(MutationsBatcherAsync).parameters + ) + batcher_init_signature.pop("table") + assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) + assert len(get_batcher_signature) == 8 + assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) + for arg_name in get_batcher_signature.keys(): + assert ( + get_batcher_signature[arg_name].default + == batcher_init_signature[arg_name].default + ) - @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) - def test_revise_limit_over_limit(self, start_limit, emit_num): + def test__start_flush_timer_w_None(self): + """Empty timer should return immediately""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + with self._make_one() as instance: + with mock.patch("asyncio.sleep") as sleep_mock: + instance._start_flush_timer(None) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + + def test__start_flush_timer_call_when_closed(self): + """closed batcher's timer should return immediately""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + with self._make_one() as instance: + instance.close() + flush_mock.reset_mock() + with mock.patch("asyncio.sleep") as sleep_mock: + instance._start_flush_timer(1) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + + def test__flush_timer(self): + """Timer should continue to call _schedule_flush in a loop""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + expected_sleep = 12 + with self._make_one(flush_interval=expected_sleep) as instance: + instance._staged_entries = [mock.Mock()] + loop_num = 3 + with mock.patch("asyncio.sleep") as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [ + ZeroDivisionError("expected") + ] + try: + instance._flush_timer + except ZeroDivisionError: + instance._flush_timer = concurrent.futures.Future() + assert sleep_mock.call_count == loop_num + 1 + sleep_mock.assert_called_with(expected_sleep) + assert flush_mock.call_count == loop_num + + def test__flush_timer_no_mutations(self): + """Timer should not flush if no new mutations have been staged""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + expected_sleep = 12 + with 
self._make_one(flush_interval=expected_sleep) as instance: + loop_num = 3 + with mock.patch("asyncio.sleep") as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] + try: + instance._flush_timer + except TabError: + instance._flush_timer = concurrent.futures.Future() + assert sleep_mock.call_count == loop_num + 1 + sleep_mock.assert_called_with(expected_sleep) + assert flush_mock.call_count == 0 + + def test__flush_timer_close(self): + """Timer should continue terminate after close""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + with self._make_one() as instance: + with mock.patch("asyncio.sleep"): + time.sleep(0.5) + assert instance._flush_timer.done() is False + instance.close() + time.sleep(0.1) + assert instance._flush_timer.done() is True + + def test_append_closed(self): + """Should raise exception""" + with pytest.raises(RuntimeError): + instance = self._make_one() + instance.close() + instance.append(mock.Mock()) + + def test_append_wrong_mutation(self): """ - Should raise runtime error if we get in state where emit_num > start_num - (unless start_num == 0, which represents unlimited) + Mutation objects should raise an exception. + Only support RowMutationEntry """ - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable_v2.types import ReadRowsResponse - from google.cloud.bigtable.data.exceptions import InvalidChunk - - def awaitable_stream(): - def mock_stream(): - for i in range(emit_num): - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk( - row_key=str(i).encode(), - family_name="b", - qualifier=b"c", - value=b"d", - commit_row=True, - ) - ] - ) - - return mock_stream() - - query = ReadRowsQuery(limit=start_limit) - table = mock.Mock() - table.table_name = "table_name" - table.app_profile_id = "app_profile_id" - instance = self._make_one(query, table, 10, 10) - assert instance._remaining_count == start_limit - with pytest.raises(InvalidChunk) as e: - for val in instance.chunk_stream(awaitable_stream()): - pass - assert "emit count exceeds row limit" in str(e.value) - - def test_close(self): + from google.cloud.bigtable.data.mutations import DeleteAllFromRow + + with self._make_one() as instance: + expected_error = "invalid mutation type: DeleteAllFromRow. Only RowMutationEntry objects are supported by batcher" + with pytest.raises(ValueError) as e: + instance.append(DeleteAllFromRow()) + assert str(e.value) == expected_error + + def test_append_outside_flow_limits(self): + """entries larger than mutation limits are still processed""" + with self._make_one( + flow_control_max_mutation_count=1, flow_control_max_bytes=1 + ) as instance: + oversized_entry = self._make_mutation(count=0, size=2) + instance.append(oversized_entry) + assert instance._staged_entries == [oversized_entry] + assert instance._staged_count == 0 + assert instance._staged_bytes == 2 + instance._staged_entries = [] + with self._make_one( + flow_control_max_mutation_count=1, flow_control_max_bytes=1 + ) as instance: + overcount_entry = self._make_mutation(count=2, size=0) + instance.append(overcount_entry) + assert instance._staged_entries == [overcount_entry] + assert instance._staged_count == 2 + assert instance._staged_bytes == 0 + instance._staged_entries = [] + + def test_append_flush_runs_after_limit_hit(self): """ - should be able to close a stream safely with aclose. 
- Closed generators should raise StopIteration on next yield + If the user appends a bunch of entries above the flush limits back-to-back, + it should still flush in a single task """ - - def mock_stream(): - while True: - yield 1 - with mock.patch.object( - self._get_target_class(), "_read_rows_attempt" - ) as mock_attempt: - instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) - wrapped_gen = mock_stream() - mock_attempt.return_value = wrapped_gen - gen = instance.start_operation() - gen.__next__() - gen.close() - with pytest.raises(StopIteration): - gen.__next__() - gen.close() - with pytest.raises(StopIteration): - wrapped_gen.__next__() - - def test_retryable_ignore_repeated_rows(self): - """Duplicate rows should cause an invalid chunk error""" - from google.cloud.bigtable.data.exceptions import InvalidChunk - from google.cloud.bigtable_v2.types import ReadRowsResponse - - row_key = b"duplicate" - - def mock_awaitable_stream(): - def mock_stream(): - while True: - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) - ] - ) - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) - ] - ) - - return mock_stream() - - instance = mock.Mock() - instance._last_yielded_row_key = None - instance._remaining_count = None - stream = self._get_target_class().chunk_stream( - instance, mock_awaitable_stream() - ) - stream.__next__() - with pytest.raises(InvalidChunk) as exc: - stream.__next__() - assert "row keys should be strictly increasing" in str(exc.value) - - -class Test_FlowControl(ABC): - @staticmethod - def _target_class(): - from google.cloud.bigtable.data._sync.mutations_batcher import _FlowControl - - return _FlowControl - - def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): - return self._target_class()(max_mutation_count, max_mutation_bytes) - - @staticmethod - def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation - - def test_ctor(self): - max_mutation_count = 9 - max_mutation_bytes = 19 - instance = self._make_one(max_mutation_count, max_mutation_bytes) - assert instance._max_mutation_count == max_mutation_count - assert instance._max_mutation_bytes == max_mutation_bytes - assert instance._in_flight_mutation_count == 0 - assert instance._in_flight_mutation_bytes == 0 - assert isinstance(instance._capacity_condition, threading.Condition) - - def test_ctor_invalid_values(self): - """Test that values are positive, and fit within expected limits""" - with pytest.raises(ValueError) as e: - self._make_one(0, 1) - assert "max_mutation_count must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - self._make_one(1, 0) - assert "max_mutation_bytes must be greater than 0" in str(e.value) + self._get_target_class(), "_execute_mutate_rows" + ) as op_mock: + with self._make_one(flush_limit_bytes=100) as instance: + + def mock_call(*args, **kwargs): + return [] + + op_mock.side_effect = mock_call + instance.append(self._make_mutation(size=99)) + num_entries = 10 + for _ in range(num_entries): + instance.append(self._make_mutation(size=1)) + print(*instance._flush_jobs) + assert op_mock.call_count == 1 + sent_batch = op_mock.call_args[0][0] + assert len(sent_batch) == 2 + assert len(instance._staged_entries) == num_entries - 1 @pytest.mark.parametrize( - "max_count,max_size,existing_count,existing_size,new_count,new_size,expected", + 
"flush_count,flush_bytes,mutation_count,mutation_bytes,expect_flush", [ - (1, 1, 0, 0, 0, 0, True), - (1, 1, 1, 1, 1, 1, False), - (10, 10, 0, 0, 0, 0, True), - (10, 10, 0, 0, 9, 9, True), - (10, 10, 0, 0, 11, 9, True), - (10, 10, 0, 1, 11, 9, True), - (10, 10, 1, 0, 11, 9, False), - (10, 10, 0, 0, 9, 11, True), - (10, 10, 1, 0, 9, 11, True), - (10, 10, 0, 1, 9, 11, False), - (10, 1, 0, 0, 1, 0, True), - (1, 10, 0, 0, 0, 8, True), - (float("inf"), float("inf"), 0, 0, 10000000000.0, 10000000000.0, True), - (8, 8, 0, 0, 10000000000.0, 10000000000.0, True), - (12, 12, 6, 6, 5, 5, True), - (12, 12, 5, 5, 6, 6, True), - (12, 12, 6, 6, 6, 6, True), - (12, 12, 6, 6, 7, 7, False), - (12, 12, 0, 0, 13, 13, True), - (12, 12, 12, 0, 0, 13, True), - (12, 12, 0, 12, 13, 0, True), - (12, 12, 1, 1, 13, 13, False), - (12, 12, 1, 1, 0, 13, False), - (12, 12, 1, 1, 13, 0, False), + (10, 10, 1, 1, False), + (10, 10, 9, 9, False), + (10, 10, 10, 1, True), + (10, 10, 1, 10, True), + (10, 10, 10, 10, True), + (1, 1, 10, 10, True), + (1, 1, 0, 0, False), ], ) - def test__has_capacity( - self, - max_count, - max_size, - existing_count, - existing_size, - new_count, - new_size, - expected, + def test_append( + self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush ): - """_has_capacity should return True if the new mutation will will not exceed the max count or size""" - instance = self._make_one(max_count, max_size) - instance._in_flight_mutation_count = existing_count - instance._in_flight_mutation_bytes = existing_size - assert instance._has_capacity(new_count, new_size) == expected + """test appending different mutations, and checking if it causes a flush""" + with self._make_one( + flush_limit_mutation_count=flush_count, flush_limit_bytes=flush_bytes + ) as instance: + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert instance._staged_entries == [] + mutation = self._make_mutation(count=mutation_count, size=mutation_bytes) + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + instance.append(mutation) + assert flush_mock.call_count == bool(expect_flush) + assert instance._staged_count == mutation_count + assert instance._staged_bytes == mutation_bytes + assert instance._staged_entries == [mutation] + instance._staged_entries = [] + + def test_append_multiple_sequentially(self): + """Append multiple mutations""" + with self._make_one( + flush_limit_mutation_count=8, flush_limit_bytes=8 + ) as instance: + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert instance._staged_entries == [] + mutation = self._make_mutation(count=2, size=3) + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + instance.append(mutation) + assert flush_mock.call_count == 0 + assert instance._staged_count == 2 + assert instance._staged_bytes == 3 + assert len(instance._staged_entries) == 1 + instance.append(mutation) + assert flush_mock.call_count == 0 + assert instance._staged_count == 4 + assert instance._staged_bytes == 6 + assert len(instance._staged_entries) == 2 + instance.append(mutation) + assert flush_mock.call_count == 1 + assert instance._staged_count == 6 + assert instance._staged_bytes == 9 + assert len(instance._staged_entries) == 3 + instance._staged_entries = [] + + def test_flush_flow_control_concurrent_requests(self): + """requests should happen in parallel if flow control breaks up single flush into batches""" + import time + + num_calls = 10 + fake_mutations = [self._make_mutation(count=1) for _ in 
range(num_calls)] + with self._make_one(flow_control_max_mutation_count=1) as instance: + with mock.patch.object( + instance, "_execute_mutate_rows", mock.Mock() + ) as op_mock: + + def mock_call(*args, **kwargs): + time.sleep(0.1) + return [] + + op_mock.side_effect = mock_call + start_time = time.monotonic() + instance._staged_entries = fake_mutations + instance._schedule_flush() + time.sleep(0.01) + for i in range(num_calls): + instance._flow_control.remove_from_flow( + [self._make_mutation(count=1)] + ) + time.sleep(0.01) + print(*instance._flush_jobs) + duration = time.monotonic() - start_time + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert duration < 0.5 + assert op_mock.call_count == num_calls + + def test_schedule_flush_no_mutations(self): + """schedule flush should return None if no staged mutations""" + with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal") as flush_mock: + for i in range(3): + assert instance._schedule_flush() is None + assert flush_mock.call_count == 0 + + def test_schedule_flush_with_mutations(self): + """if new mutations exist, should add a new flush task to _flush_jobs""" + with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal") as flush_mock: + for i in range(1, 4): + mutation = mock.Mock() + instance._staged_entries = [mutation] + instance._schedule_flush() + assert instance._staged_entries == [] + time.sleep(0) + assert instance._staged_entries == [] + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert flush_mock.call_count == i + + def test__flush_internal(self): + """ + _flush_internal should: + - await previous flush call + - delegate batching to _flow_control + - call _execute_mutate_rows on each batch + - update self.exceptions and self._entries_processed_since_last_raise + """ + num_entries = 10 + with self._make_one() as instance: + with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: + with mock.patch.object( + instance._flow_control, "add_to_flow" + ) as flow_mock: + + def gen(x): + yield x + + flow_mock.side_effect = lambda x: gen(x) + mutations = [self._make_mutation(count=1, size=1)] * num_entries + instance._flush_internal(mutations) + assert instance._entries_processed_since_last_raise == num_entries + assert execute_mock.call_count == 1 + assert flow_mock.call_count == 1 + instance._oldest_exceptions.clear() + instance._newest_exceptions.clear() + + def test_flush_clears_job_list(self): + """ + a job should be added to _flush_jobs when _schedule_flush is called, + and removed when it completes + """ + with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal", mock.Mock()): + mutations = [self._make_mutation(count=1, size=1)] + instance._staged_entries = mutations + assert instance._flush_jobs == set() + new_job = instance._schedule_flush() + assert instance._flush_jobs == {new_job} + new_job + assert instance._flush_jobs == set() @pytest.mark.parametrize( - "existing_count,existing_size,added_count,added_size,new_count,new_size", + "num_starting,num_new_errors,expected_total_errors", [ - (0, 0, 0, 0, 0, 0), - (2, 2, 1, 1, 1, 1), - (2, 0, 1, 0, 1, 0), - (0, 2, 0, 1, 0, 1), - (10, 10, 0, 0, 10, 10), - (10, 10, 5, 5, 5, 5), - (0, 0, 1, 1, -1, -1), + (0, 0, 0), + (0, 1, 1), + (0, 2, 2), + (1, 0, 1), + (1, 1, 2), + (10, 2, 12), + (10, 20, 20), ], ) - def test_remove_from_flow_value_update( - self, - existing_count, - existing_size, - added_count, 
- added_size, - new_count, - new_size, + def test__flush_internal_with_errors( + self, num_starting, num_new_errors, expected_total_errors ): - """completed mutations should lower the inflight values""" - instance = self._make_one() - instance._in_flight_mutation_count = existing_count - instance._in_flight_mutation_bytes = existing_size - mutation = self._make_mutation(added_count, added_size) - instance.remove_from_flow(mutation) - assert instance._in_flight_mutation_count == new_count - assert instance._in_flight_mutation_bytes == new_size - - def test__remove_from_flow_unlock(self): - """capacity condition should notify after mutation is complete""" - import inspect - - instance = self._make_one(10, 10) - instance._in_flight_mutation_count = 10 - instance._in_flight_mutation_bytes = 10 - - def task_routine(): - with instance._capacity_condition: - instance._capacity_condition.wait_for( - lambda: instance._has_capacity(1, 1) + """errors returned from _execute_mutate_rows should be added to internal exceptions""" + from google.cloud.bigtable.data import exceptions + + num_entries = 10 + expected_errors = [ + exceptions.FailedMutationEntryError(mock.Mock(), mock.Mock(), ValueError()) + ] * num_new_errors + with self._make_one() as instance: + instance._oldest_exceptions = [mock.Mock()] * num_starting + with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: + execute_mock.return_value = expected_errors + with mock.patch.object( + instance._flow_control, "add_to_flow" + ) as flow_mock: + + def gen(x): + yield x + + flow_mock.side_effect = lambda x: gen(x) + mutations = [self._make_mutation(count=1, size=1)] * num_entries + instance._flush_internal(mutations) + assert instance._entries_processed_since_last_raise == num_entries + assert execute_mock.call_count == 1 + assert flow_mock.call_count == 1 + found_exceptions = instance._oldest_exceptions + list( + instance._newest_exceptions + ) + assert len(found_exceptions) == expected_total_errors + for i in range(num_starting, expected_total_errors): + assert found_exceptions[i] == expected_errors[i - num_starting] + assert found_exceptions[i].index is None + instance._oldest_exceptions.clear() + instance._newest_exceptions.clear() + + def _mock_gapic_return(self, num=5): + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc import status_pb2 + + def gen(num): + for i in range(num): + entry = MutateRowsResponse.Entry( + index=i, status=status_pb2.Status(code=0) + ) + yield MutateRowsResponse(entries=[entry]) + + return gen(num) + + def test_timer_flush_end_to_end(self): + """Flush should automatically trigger after flush_interval""" + num_nutations = 10 + mutations = [self._make_mutation(count=2, size=2)] * num_nutations + with self._make_one(flush_interval=0.05) as instance: + instance._table.default_operation_timeout = 10 + instance._table.default_attempt_timeout = 9 + with mock.patch.object( + instance._table.client._gapic_client, "mutate_rows" + ) as gapic_mock: + gapic_mock.side_effect = ( + lambda *args, **kwargs: self._mock_gapic_return(num_nutations) ) + for m in mutations: + instance.append(m) + assert instance._entries_processed_since_last_raise == 0 + time.sleep(0.1) + assert instance._entries_processed_since_last_raise == num_nutations + + def test__execute_mutate_rows(self): + mutate_path = self._get_mutate_rows_class_path() + with mock.patch(mutate_path) as mutate_rows: + mutate_rows.return_value = mock.Mock() + start_operation = mutate_rows().start + table = mock.Mock() + 
table.table_name = "test-table" + table.app_profile_id = "test-app-profile" + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + with self._make_one(table) as instance: + batch = [self._make_mutation()] + result = instance._execute_mutate_rows(batch) + assert start_operation.call_count == 1 + (args, kwargs) = mutate_rows.call_args + assert args[0] == table.client._gapic_client + assert args[1] == table + assert args[2] == batch + kwargs["operation_timeout"] == 17 + kwargs["attempt_timeout"] == 13 + assert result == [] + + def test__execute_mutate_rows_returns_errors(self): + """Errors from operation should be retruned as list""" + from google.cloud.bigtable.data.exceptions import ( + MutationsExceptionGroup, + FailedMutationEntryError, + ) + + cls_path = self._get_mutate_rows_class_path() + with mock.patch(f"{cls_path}.start") as mutate_rows: + err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) + err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) + mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + with self._make_one(table) as instance: + batch = [self._make_mutation()] + result = instance._execute_mutate_rows(batch) + assert len(result) == 2 + assert result[0] == err1 + assert result[1] == err2 + assert result[0].index is None + assert result[1].index is None + + def test__raise_exceptions(self): + """Raise exceptions and reset error state""" + from google.cloud.bigtable.data import exceptions + + expected_total = 1201 + expected_exceptions = [RuntimeError("mock")] * 3 + with self._make_one() as instance: + instance._oldest_exceptions = expected_exceptions + instance._entries_processed_since_last_raise = expected_total + try: + instance._raise_exceptions() + except exceptions.MutationsExceptionGroup as exc: + assert list(exc.exceptions) == expected_exceptions + assert str(expected_total) in str(exc) + assert instance._entries_processed_since_last_raise == 0 + (instance._oldest_exceptions, instance._newest_exceptions) = ([], []) + instance._raise_exceptions() + + def test___aenter__(self): + """Should return self""" + with self._make_one() as instance: + assert instance.__enter__() == instance + + def test___aexit__(self): + """aexit should call close""" + with self._make_one() as instance: + with mock.patch.object(instance, "close") as close_mock: + instance.__exit__(None, None, None) + assert close_mock.call_count == 1 - if inspect.iscoroutinefunction(task_routine): - task = threading.Thread(task_routine()) - task_alive = lambda: not task.done() - else: - import threading - - thread = threading.Thread(target=task_routine) - thread.start() - task_alive = thread.is_alive - time.sleep(0.05) - assert task_alive() is True - mutation = self._make_mutation(count=0, size=5) - instance.remove_from_flow([mutation]) - time.sleep(0.05) - assert instance._in_flight_mutation_count == 10 - assert instance._in_flight_mutation_bytes == 5 - assert task_alive() is True - instance._in_flight_mutation_bytes = 10 - mutation = self._make_mutation(count=5, size=0) - instance.remove_from_flow([mutation]) - time.sleep(0.05) - assert instance._in_flight_mutation_count == 5 - assert instance._in_flight_mutation_bytes == 10 - assert task_alive() is True - 
instance._in_flight_mutation_count = 10 - mutation = self._make_mutation(count=5, size=5) - instance.remove_from_flow([mutation]) - time.sleep(0.05) - assert instance._in_flight_mutation_count == 5 - assert instance._in_flight_mutation_bytes == 5 - assert task_alive() is False + def test_close(self): + """Should clean up all resources""" + with self._make_one() as instance: + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + with mock.patch.object(instance, "_raise_exceptions") as raise_mock: + instance.close() + assert instance.closed is True + assert instance._flush_timer.done() is True + assert instance._flush_jobs == set() + assert flush_mock.call_count == 1 + assert raise_mock.call_count == 1 + + def test_close_w_exceptions(self): + """Raise exceptions on close""" + from google.cloud.bigtable.data import exceptions + + expected_total = 10 + expected_exceptions = [RuntimeError("mock")] + with self._make_one() as instance: + instance._oldest_exceptions = expected_exceptions + instance._entries_processed_since_last_raise = expected_total + try: + instance.close() + except exceptions.MutationsExceptionGroup as exc: + assert list(exc.exceptions) == expected_exceptions + assert str(expected_total) in str(exc) + assert instance._entries_processed_since_last_raise == 0 + (instance._oldest_exceptions, instance._newest_exceptions) = ([], []) + + def test__on_exit(self, recwarn): + """Should raise warnings if unflushed mutations exist""" + with self._make_one() as instance: + instance._on_exit() + assert len(recwarn) == 0 + num_left = 4 + instance._staged_entries = [mock.Mock()] * num_left + with pytest.warns(UserWarning) as w: + instance._on_exit() + assert len(w) == 1 + assert "unflushed mutations" in str(w[0].message).lower() + assert str(num_left) in str(w[0].message) + instance.closed = True + instance._on_exit() + assert len(recwarn) == 0 + instance._staged_entries = [] + + def test_atexit_registration(self): + """Should run _on_exit on program termination""" + import atexit + + with mock.patch.object(atexit, "register") as register_mock: + assert register_mock.call_count == 0 + with self._make_one(): + assert register_mock.call_count == 1 + + def test_timeout_args_passed(self): + """ + batch_operation_timeout and batch_attempt_timeout should be used + in api calls + """ + mutate_path = self._get_mutate_rows_class_path() + with mock.patch(mutate_path, return_value=mock.Mock()) as mutate_rows: + expected_operation_timeout = 17 + expected_attempt_timeout = 13 + with self._make_one( + batch_operation_timeout=expected_operation_timeout, + batch_attempt_timeout=expected_attempt_timeout, + ) as instance: + assert instance._operation_timeout == expected_operation_timeout + assert instance._attempt_timeout == expected_attempt_timeout + instance._execute_mutate_rows([self._make_mutation()]) + assert mutate_rows.call_count == 1 + kwargs = mutate_rows.call_args[1] + assert kwargs["operation_timeout"] == expected_operation_timeout + assert kwargs["attempt_timeout"] == expected_attempt_timeout @pytest.mark.parametrize( - "mutations,count_cap,size_cap,expected_results", + "limit,in_e,start_e,end_e", [ - ([(5, 5), (1, 1), (1, 1)], 10, 10, [[(5, 5), (1, 1), (1, 1)]]), - ([(1, 1), (1, 1), (1, 1)], 1, 1, [[(1, 1)], [(1, 1)], [(1, 1)]]), - ([(1, 1), (1, 1), (1, 1)], 2, 10, [[(1, 1), (1, 1)], [(1, 1)]]), - ([(1, 1), (1, 1), (1, 1)], 10, 2, [[(1, 1), (1, 1)], [(1, 1)]]), - ( - [(1, 1), (5, 5), (4, 1), (1, 4), (1, 1)], - 5, - 5, - [[(1, 1)], [(5, 5)], [(4, 1), (1, 4)], [(1, 1)]], - ), + (10, 
0, (10, 0), (10, 0)), + (1, 10, (0, 0), (1, 1)), + (10, 1, (0, 0), (1, 0)), + (10, 10, (0, 0), (10, 0)), + (10, 11, (0, 0), (10, 1)), + (3, 20, (0, 0), (3, 3)), + (10, 20, (0, 0), (10, 10)), + (10, 21, (0, 0), (10, 10)), + (2, 1, (2, 0), (2, 1)), + (2, 1, (1, 0), (2, 0)), + (2, 2, (1, 0), (2, 1)), + (3, 1, (3, 1), (3, 2)), + (3, 3, (3, 1), (3, 3)), + (1000, 5, (999, 0), (1000, 4)), + (1000, 5, (0, 0), (5, 0)), + (1000, 5, (1000, 0), (1000, 5)), ], ) - def test_add_to_flow(self, mutations, count_cap, size_cap, expected_results): - """Test batching with various flow control settings""" - mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] - instance = self._make_one(count_cap, size_cap) - i = 0 - for batch in instance.add_to_flow(mutation_objs): - expected_batch = expected_results[i] - assert len(batch) == len(expected_batch) - for j in range(len(expected_batch)): - assert len(batch[j].mutations) == expected_batch[j][0] - assert batch[j].size() == expected_batch[j][1] - instance.remove_from_flow(batch) - i += 1 - assert i == len(expected_results) + def test__add_exceptions(self, limit, in_e, start_e, end_e): + """ + Test that the _add_exceptions function properly updates the + _oldest_exceptions and _newest_exceptions lists + Args: + - limit: the _exception_list_limit representing the max size of either list + - in_e: size of list of exceptions to send to _add_exceptions + - start_e: a tuple of ints representing the initial sizes of _oldest_exceptions and _newest_exceptions + - end_e: a tuple of ints representing the expected sizes of _oldest_exceptions and _newest_exceptions + """ + from collections import deque + + input_list = [RuntimeError(f"mock {i}") for i in range(in_e)] + mock_batcher = mock.Mock() + mock_batcher._oldest_exceptions = [ + RuntimeError(f"starting mock {i}") for i in range(start_e[0]) + ] + mock_batcher._newest_exceptions = deque( + [RuntimeError(f"starting mock {i}") for i in range(start_e[1])], + maxlen=limit, + ) + mock_batcher._exception_list_limit = limit + mock_batcher._exceptions_since_last_raise = 0 + self._get_target_class()._add_exceptions(mock_batcher, input_list) + assert len(mock_batcher._oldest_exceptions) == end_e[0] + assert len(mock_batcher._newest_exceptions) == end_e[1] + assert mock_batcher._exceptions_since_last_raise == in_e + oldest_list_diff = end_e[0] - start_e[0] + newest_list_diff = min(max(in_e - oldest_list_diff, 0), limit) + for i in range(oldest_list_diff): + assert mock_batcher._oldest_exceptions[i + start_e[0]] == input_list[i] + for i in range(1, newest_list_diff + 1): + assert mock_batcher._newest_exceptions[-i] == input_list[-i] @pytest.mark.parametrize( - "mutations,max_limit,expected_results", + "input_retryables,expected_retryables", [ - ([(1, 1)] * 11, 10, [[(1, 1)] * 10, [(1, 1)]]), - ([(1, 1)] * 10, 1, [[(1, 1)] for _ in range(10)]), - ([(1, 1)] * 10, 2, [[(1, 1), (1, 1)] for _ in range(5)]), + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), ], ) - def test_add_to_flow_max_mutation_limits( - self, mutations, max_limit, expected_results - ): + def test_customizable_retryable_errors(self, input_retryables, expected_retryables): """ - 
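
# A minimal sketch of the two-tier error bookkeeping that test__add_exceptions
# describes above: the first `limit` failures are kept permanently in
# _oldest_exceptions, anything after that rotates through a bounded
# _newest_exceptions deque, so memory stays fixed while both the earliest and the
# most recent failures can still be reported. Names mirror the batcher attributes,
# but this is an illustration of the pattern, not the library implementation.
from __future__ import annotations

from collections import deque


class _ExceptionTracker:
    def __init__(self, limit: int = 10):
        self._exception_list_limit = limit
        self._oldest_exceptions: list[Exception] = []  # first `limit` errors, never evicted
        self._newest_exceptions: deque[Exception] = deque(maxlen=limit)  # most recent errors
        self._exceptions_since_last_raise = 0

    def _add_exceptions(self, excs: list[Exception]) -> None:
        self._exceptions_since_last_raise += len(excs)
        for exc in excs:
            if len(self._oldest_exceptions) < self._exception_list_limit:
                self._oldest_exceptions.append(exc)  # fill the "oldest" bucket first
            else:
                self._newest_exceptions.append(exc)  # then rotate the "newest" bucket


tracker = _ExceptionTracker(limit=3)
tracker._add_exceptions([RuntimeError(f"mock {i}") for i in range(5)])
assert len(tracker._oldest_exceptions) == 3
assert len(tracker._newest_exceptions) == 2
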
Test flow control running up against the max API limit - Should submit request early, even if the flow control has room for more + Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer. """ - async_patch = mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", - max_limit, - ) - sync_patch = mock.patch( - "google.cloud.bigtable.data._sync._autogen._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", - max_limit, - ) - with async_patch, sync_patch: - mutation_objs = [ - self._make_mutation(count=m[0], size=m[1]) for m in mutations - ] - instance = self._make_one(float("inf"), float("inf")) - i = 0 - for batch in instance.add_to_flow(mutation_objs): - expected_batch = expected_results[i] - assert len(batch) == len(expected_batch) - for j in range(len(expected_batch)): - assert len(batch[j].mutations) == expected_batch[j][0] - assert batch[j].size() == expected_batch[j][1] - instance.remove_from_flow(batch) - i += 1 - assert i == len(expected_results) - - def test_add_to_flow_oversize(self): - """mutations over the flow control limits should still be accepted""" - instance = self._make_one(2, 3) - large_size_mutation = self._make_mutation(count=1, size=10) - large_count_mutation = self._make_mutation(count=10, size=1) - results = [out for out in instance.add_to_flow([large_size_mutation])] - assert len(results) == 1 - instance.remove_from_flow(results[0]) - count_results = [out for out in instance.add_to_flow(large_count_mutation)] - assert len(count_results) == 1 + from google.cloud.bigtable.data._async.client import TableAsync + + with mock.patch( + "google.api_core.retry.if_exception_type" + ) as predicate_builder_mock: + with mock.patch( + "google.api_core.retry.retry_target_async" + ) as retry_fn_mock: + table = None + with mock.patch("asyncio.create_task"): + table = TableAsync(mock.Mock(), "instance", "table") + with self._make_one( + table, batch_retryable_errors=input_retryables + ) as instance: + assert instance._retryable_errors == expected_retryables + expected_predicate = lambda a: a in expected_retryables + predicate_builder_mock.return_value = expected_predicate + retry_fn_mock.side_effect = RuntimeError("stop early") + mutation = self._make_mutation(count=1, size=1) + instance._execute_mutate_rows([mutation]) + predicate_builder_mock.assert_called_once_with( + *expected_retryables, _MutateRowsIncomplete + ) + retry_call_args = retry_fn_mock.call_args_list[0].args + assert retry_call_args[1] is expected_predicate From aa472c13931886b42b7a995af6d4cbf2056870af Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Apr 2024 13:28:36 -0700 Subject: [PATCH 033/360] fixed TableAssync references --- google/cloud/bigtable/data/__init__.py | 12 ++++++------ sync_surface_generator.py | 15 ++++++++------- tests/unit/data/_async/test_mutations_batcher.py | 9 ++------- tests/unit/data/_sync/test_autogen.py | 13 ++++--------- 4 files changed, 20 insertions(+), 29 deletions(-) diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index cdb7622b6..fd44fe86c 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -20,10 +20,10 @@ from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync -# from google.cloud.bigtable.data._sync.client import BigtableDataClient -# from google.cloud.bigtable.data._sync.client import Table +from google.cloud.bigtable.data._sync.client 
import BigtableDataClient +from google.cloud.bigtable.data._sync.client import Table -# from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher +from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.read_rows_query import RowRange @@ -53,9 +53,9 @@ __version__: str = package_version.__version__ __all__ = ( - # "BigtableDataClient", - # "Table", - # "MutationsBatcher", + "BigtableDataClient", + "Table", + "MutationsBatcher", "BigtableDataClientAsync", "TableAsync", "MutationsBatcherAsync", diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 41f248662..d722de008 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -158,13 +158,14 @@ def visit_Attribute(self, node): ): replacement = self.asyncio_replacements[node.attr] return ast.copy_location(ast.parse(replacement, mode="eval").body, node) - if node.attr in self.text_replacements: - # replace from text_replacements - new_node = ast.copy_location( - ast.Attribute(self.visit(node.value), self.text_replacements[node.attr], node.ctx), node - ) - return new_node - return node + fixed = ast.copy_location( + ast.Attribute( + self.visit(node.value), + self.text_replacements.get(node.attr, node.attr), # replace attr value + node.ctx + ), node + ) + return fixed def visit_Name(self, node): node.id = self.text_replacements.get(node.id, node.id) diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index f18f5884b..c0e02cf0e 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -17,6 +17,7 @@ import google.api_core.exceptions as core_exceptions from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.cloud.bigtable.data import TABLE_DEFAULT +from google.cloud.bigtable.data import TableAsync # try/except added for compatibility with python < 3.8 try: @@ -467,10 +468,6 @@ def test_default_argument_consistency(self): table.mutations_batcher. Make sure any changes to defaults are applied to both places """ - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) import inspect get_batcher_signature = dict( @@ -478,7 +475,7 @@ def test_default_argument_consistency(self): ) get_batcher_signature.pop("self") batcher_init_signature = dict( - inspect.signature(MutationsBatcherAsync).parameters + inspect.signature(self._get_target_class()).parameters ) batcher_init_signature.pop("table") # both should have same number of arguments @@ -1161,8 +1158,6 @@ async def test_customizable_retryable_errors( Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. 
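
# The visit_Attribute change in sync_surface_generator.py above rebuilds every
# Attribute node and routes its attribute name through the replacement map. A
# self-contained sketch of that pattern (illustrative only, not the generator
# itself; ast.unparse needs Python 3.9+):
import ast


class _RenameAttrs(ast.NodeTransformer):
    def __init__(self, replacements: dict):
        self.replacements = replacements

    def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
        return ast.copy_location(
            ast.Attribute(
                self.visit(node.value),
                self.replacements.get(node.attr, node.attr),  # rename if mapped, else keep
                node.ctx,
            ),
            node,
        )


tree = ast.parse("value = stream.__anext__()")
tree = ast.fix_missing_locations(_RenameAttrs({"__anext__": "__next__"}).visit(tree))
print(ast.unparse(tree))  # value = stream.__next__()
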
""" - from google.cloud.bigtable.data._async.client import TableAsync - with mock.patch( "google.api_core.retry.if_exception_type" ) as predicate_builder_mock: diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py index e9692a7e7..02bce4a49 100644 --- a/tests/unit/data/_sync/test_autogen.py +++ b/tests/unit/data/_sync/test_autogen.py @@ -23,6 +23,7 @@ import time from google.cloud.bigtable.data import TABLE_DEFAULT +from google.cloud.bigtable.data import Table from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete import google.api_core.exceptions as core_exceptions @@ -201,18 +202,14 @@ def test_default_argument_consistency(self): table.mutations_batcher. Make sure any changes to defaults are applied to both places """ - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) import inspect get_batcher_signature = dict( - inspect.signature(TableAsync.mutations_batcher).parameters + inspect.signature(Table.mutations_batcher).parameters ) get_batcher_signature.pop("self") batcher_init_signature = dict( - inspect.signature(MutationsBatcherAsync).parameters + inspect.signature(self._get_target_class()).parameters ) batcher_init_signature.pop("table") assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) @@ -838,8 +835,6 @@ def test_customizable_retryable_errors(self, input_retryables, expected_retryabl Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. """ - from google.cloud.bigtable.data._async.client import TableAsync - with mock.patch( "google.api_core.retry.if_exception_type" ) as predicate_builder_mock: @@ -848,7 +843,7 @@ def test_customizable_retryable_errors(self, input_retryables, expected_retryabl ) as retry_fn_mock: table = None with mock.patch("asyncio.create_task"): - table = TableAsync(mock.Mock(), "instance", "table") + table = Table(mock.Mock(), "instance", "table") with self._make_one( table, batch_retryable_errors=input_retryables ) as instance: From 06f058606f71f83e1bce3b6e56535ae68e0ac015 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Apr 2024 14:59:36 -0700 Subject: [PATCH 034/360] added event for batcher lock --- .../bigtable/data/_async/mutations_batcher.py | 29 +++++++--- google/cloud/bigtable/data/_sync/_autogen.py | 54 +++++------------- .../bigtable/data/_sync/mutations_batcher.py | 57 ++++++++++++++++++- .../cloud/bigtable/data/_sync/sync_gen.yaml | 4 +- .../data/_async/test_mutations_batcher.py | 24 ++++---- tests/unit/data/_sync/test_autogen.py | 31 +++++----- 6 files changed, 121 insertions(+), 78 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 4dfd0df1a..7e003688d 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -221,7 +221,7 @@ def __init__( batch_retryable_errors, table ) - self.closed: bool = False + self._closed: bool = asyncio.Event() self._table = table self._staged_entries: list[RowMutationEntry] = [] self._staged_count, self._staged_bytes = 0, 0 @@ -260,7 +260,7 @@ def _start_flush_timer(self, interval: float | None) -> asyncio.Future[None]: Returns: - asyncio.Future that represents the background task """ - if interval is None or self.closed: + if interval is None or self._closed.is_set(): 
empty_future: asyncio.Future[None] = asyncio.Future() empty_future.set_result(None) return empty_future @@ -269,10 +269,13 @@ async def timer_routine(self, interval: float): """ Triggers new flush tasks every `interval` seconds """ - while not self.closed: - await asyncio.sleep(interval) - # add new flush task to list - if not self.closed and self._staged_entries: + while not self._closed.is_set(): + # wait until interval has passed, or until closed + try: + await asyncio.wait_for(self._closed.wait(), timeout=interval) + except asyncio.TimeoutError: + pass + if not self._closed.is_set() and self._staged_entries: self._schedule_flush() return self._create_bg_task(timer_routine, self, interval) @@ -288,7 +291,7 @@ async def append(self, mutation_entry: RowMutationEntry): - RuntimeError if batcher is closed - ValueError if an invalid mutation type is added """ - if self.closed: + if self._closed.is_set(): raise RuntimeError("Cannot append to closed MutationsBatcher") if isinstance(mutation_entry, Mutation): # type: ignore raise ValueError( @@ -419,11 +422,19 @@ async def __aexit__(self, exc_type, exc, tb): """For context manager API""" await self.close() + @property + def closed(self) -> bool: + """ + Returns: + - True if the batcher is closed, False otherwise + """ + return self._closed.is_set() + async def close(self): """ Flush queue and clean up resources """ - self.closed = True + self._closed.set() self._flush_timer.cancel() self._schedule_flush() if self._flush_jobs: @@ -440,7 +451,7 @@ def _on_exit(self): """ Called when program is exited. Raises warning if unflushed mutations remain """ - if not self.closed and self._staged_entries: + if not self._closed.is_set() and self._staged_entries: warnings.warn( f"MutationsBatcher for table {self._table.table_name} was not closed. " f"{len(self._staged_entries)} Unflushed mutations will not be sent to the server." diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 677c95c47..3e3d857c4 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -546,7 +546,7 @@ def __init__( self._retryable_errors: list[type[Exception]] = _get_retryable_errors( batch_retryable_errors, table ) - self.closed: bool = False + self._closed: bool = threading.Event() self._table = table self._staged_entries: list[RowMutationEntry] = [] (self._staged_count, self._staged_bytes) = (0, 0) @@ -575,30 +575,7 @@ def __init__( def _start_flush_timer( self, interval: float | None ) -> concurrent.futures.Future[None]: - """ - Set up a background task to flush the batcher every interval seconds - - If interval is None, an empty future is returned - - Args: - - flush_interval: Automatically flush every flush_interval seconds. - If None, no time-based flushing is performed. 
- Returns: - - asyncio.Future that represents the background task - """ - if interval is None or self.closed: - empty_future: concurrent.futures.Future[None] = concurrent.futures.Future() - empty_future.set_result(None) - return empty_future - - def timer_routine(self, interval: float): - """Triggers new flush tasks every `interval` seconds""" - while not self.closed: - time.sleep(interval) - if not self.closed and self._staged_entries: - self._schedule_flush() - - return self._create_bg_task(timer_routine, self, interval) + raise NotImplementedError("Function not implemented in sync class") def append(self, mutation_entry: RowMutationEntry): """ @@ -612,7 +589,7 @@ def append(self, mutation_entry: RowMutationEntry): - RuntimeError if batcher is closed - ValueError if an invalid mutation type is added """ - if self.closed: + if self._closed.is_set(): raise RuntimeError("Cannot append to closed MutationsBatcher") if isinstance(mutation_entry, Mutation): raise ValueError( @@ -738,12 +715,20 @@ def __exit__(self, exc_type, exc, tb): """For context manager API""" self.close() + @property + def closed(self) -> bool: + """ + Returns: + - True if the batcher is closed, False otherwise + """ + return self._closed.is_set() + def close(self): - """Flush queue and clean up resources""" + raise NotImplementedError("Function not implemented in sync class") def _on_exit(self): """Called when program is exited. Raises warning if unflushed mutations remain""" - if not self.closed and self._staged_entries: + if not self._closed.is_set() and self._staged_entries: warnings.warn( f"MutationsBatcher for table {self._table.table_name} was not closed. {len(self._staged_entries)} Unflushed mutations will not be sent to the server." ) @@ -757,18 +742,7 @@ def _wait_for_batch_results( *tasks: concurrent.futures.Future[list[FailedMutationEntryError]] | concurrent.futures.Future[None], ) -> list[Exception]: - """ - Takes in a list of futures representing _execute_mutate_rows tasks, - waits for them to complete, and returns a list of errors encountered. - - Args: - - *tasks: futures representing _execute_mutate_rows or _flush_internal tasks - Returns: - - list of Exceptions encountered by any of the tasks. Errors are expected - to be FailedMutationEntryError, representing a failed mutation operation. - If a task fails with a different exception, it will be included in the - output list. Successful tasks will not be represented in the output list. 
- """ + raise NotImplementedError("Function not implemented in sync class") class _FlowControl_SyncGen(ABC): diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 8a1332f3e..f64b85dc6 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -14,7 +14,8 @@ # from __future__ import annotations -from concurrent.futures import ThreadPoolExecutor +import concurrent.futures +import atexit from google.cloud.bigtable.data._sync._autogen import _FlowControl_SyncGen from google.cloud.bigtable.data._sync._autogen import MutationsBatcher_SyncGen @@ -29,8 +30,60 @@ class _FlowControl(_FlowControl_SyncGen): class MutationsBatcher(MutationsBatcher_SyncGen): def __init__(self, *args, **kwargs): - self._executor = ThreadPoolExecutor(max_workers=8) + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) super().__init__(*args, **kwargs) + def close(self): + """ + Flush queue and clean up resources + """ + self._closed.set() + # attempt cancel timer if not started + self._flush_timer.cancel() + self._schedule_flush() + self._executor.shutdown(wait=True) + atexit.unregister(self._on_exit) + # raise unreported exceptions + self._raise_exceptions() + def _create_bg_task(self, func, *args, **kwargs): return self._executor.submit(func, *args, **kwargs) + + @staticmethod + def _wait_for_batch_results( + *tasks: concurrent.futures.Future[list[FailedMutationEntryError]] + | concurrent.futures.Future[None], + ) -> list[Exception]: + if not tasks: + return [] + exceptions = [] + for task in tasks: + try: + exc_list = task.result() + for exc in exc_list: + # strip index information + exc.index = None + exceptions.extend(exc_list) + except Exception as e: + exceptions.append(e) + return exceptions + + def _start_flush_timer( + self, interval: float | None + ) -> concurrent.futures.Future[None]: + if interval is None or self._closed.is_set(): + empty_future: concurrent.futures.Future[None] = concurrent.futures.Future() + empty_future.set_result(None) + return empty_future + + def timer_routine(self, interval: float): + """ + Triggers new flush tasks every `interval` seconds + """ + while not self._closed.is_set(): + # wait until interval has passed, or until closed + self._closed.wait(interval) + if not self._closed.is_set() and self._staged_entries: + self._schedule_flush() + return self._create_bg_task(timer_routine, self, interval) + diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index a9ebe5db4..203e921ff 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -4,6 +4,7 @@ asyncio_replacements: # Replace asyncio functionaility Condition: threading.Condition Future: concurrent.futures.Future Task: threading.Thread + Event: threading.Event text_replacements: # Find and replace specific text patterns __anext__: __next__ @@ -40,8 +41,7 @@ classes: # Specify transformations for individual classes - path: google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync autogen_sync_name: MutationsBatcher_SyncGen concrete_path: google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher - pass_methods: ["close", "_wait_for_batch_results"] - error_methods: ["_create_bg_task"] + error_methods: ["_create_bg_task", "close", "_wait_for_batch_results", "_start_flush_timer"] - path: 
google.cloud.bigtable.data._async.mutations_batcher._FlowControlAsync autogen_sync_name: _FlowControl_SyncGen concrete_path: google.cloud.bigtable.data._sync.mutations_batcher._FlowControl diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index c0e02cf0e..4a44c980d 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -15,6 +15,7 @@ import pytest import asyncio import google.api_core.exceptions as core_exceptions +import google.api_core.retry from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data import TableAsync @@ -520,7 +521,7 @@ async def test__flush_timer(self): async with self._make_one(flush_interval=expected_sleep) as instance: instance._staged_entries = [mock.Mock()] loop_num = 3 - with mock.patch("asyncio.sleep") as sleep_mock: + with mock.patch("asyncio.wait_for") as sleep_mock: sleep_mock.side_effect = [None] * loop_num + [ZeroDivisionError("expected")] try: await instance._flush_timer @@ -528,7 +529,8 @@ async def test__flush_timer(self): # replace with np-op so there are no issues on close instance._flush_timer = asyncio.Future() assert sleep_mock.call_count == loop_num + 1 - sleep_mock.assert_called_with(expected_sleep) + sleep_kwargs = sleep_mock.call_args[1] + assert sleep_kwargs["timeout"] == expected_sleep assert flush_mock.call_count == loop_num @pytest.mark.asyncio @@ -538,7 +540,7 @@ async def test__flush_timer_no_mutations(self): expected_sleep = 12 async with self._make_one(flush_interval=expected_sleep) as instance: loop_num = 3 - with mock.patch("asyncio.sleep") as sleep_mock: + with mock.patch("asyncio.wait_for") as sleep_mock: sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] try: await instance._flush_timer @@ -546,13 +548,14 @@ async def test__flush_timer_no_mutations(self): # replace with np-op so there are no issues on close instance._flush_timer = asyncio.Future() assert sleep_mock.call_count == loop_num + 1 - sleep_mock.assert_called_with(expected_sleep) + sleep_kwargs = sleep_mock.call_args[1] + assert sleep_kwargs["timeout"] == expected_sleep assert flush_mock.call_count == 0 @pytest.mark.asyncio async def test__flush_timer_close(self): """Timer should continue terminate after close""" - with mock.patch.object(self._get_target_class(), "_schedule_flush") as flush_mock: + with mock.patch.object(self._get_target_class(), "_schedule_flush"): async with self._make_one() as instance: with mock.patch("asyncio.sleep"): # let task run in background @@ -1028,7 +1031,7 @@ async def test__on_exit(self, recwarn): assert "unflushed mutations" in str(w[0].message).lower() assert str(num_left) in str(w[0].message) # calling while closed is noop - instance.closed = True + instance._closed.set() instance._on_exit() assert len(recwarn) == 0 # reset staged mutations for cleanup @@ -1158,12 +1161,9 @@ async def test_customizable_retryable_errors( Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. 
""" - with mock.patch( - "google.api_core.retry.if_exception_type" - ) as predicate_builder_mock: - with mock.patch( - "google.api_core.retry.retry_target_async" - ) as retry_fn_mock: + retryn_fn = "retry_target_async" if "Async" in self._get_target_class().__name__ else "retry_target" + with mock.patch.object(google.api_core.retry, "if_exception_type") as predicate_builder_mock: + with mock.patch.object(google.api_core.retry, retryn_fn) as retry_fn_mock: table = None with mock.patch("asyncio.create_task"): table = TableAsync(mock.Mock(), "instance", "table") diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py index 02bce4a49..5d82121f0 100644 --- a/tests/unit/data/_sync/test_autogen.py +++ b/tests/unit/data/_sync/test_autogen.py @@ -25,7 +25,9 @@ from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data import Table from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +import google.api_core.exceptions import google.api_core.exceptions as core_exceptions +import google.api_core.retry class TestMutationsBatcher(ABC): @@ -254,7 +256,7 @@ def test__flush_timer(self): with self._make_one(flush_interval=expected_sleep) as instance: instance._staged_entries = [mock.Mock()] loop_num = 3 - with mock.patch("asyncio.sleep") as sleep_mock: + with mock.patch("asyncio.wait_for") as sleep_mock: sleep_mock.side_effect = [None] * loop_num + [ ZeroDivisionError("expected") ] @@ -263,7 +265,8 @@ def test__flush_timer(self): except ZeroDivisionError: instance._flush_timer = concurrent.futures.Future() assert sleep_mock.call_count == loop_num + 1 - sleep_mock.assert_called_with(expected_sleep) + sleep_kwargs = sleep_mock.call_args[1] + assert sleep_kwargs["timeout"] == expected_sleep assert flush_mock.call_count == loop_num def test__flush_timer_no_mutations(self): @@ -274,21 +277,20 @@ def test__flush_timer_no_mutations(self): expected_sleep = 12 with self._make_one(flush_interval=expected_sleep) as instance: loop_num = 3 - with mock.patch("asyncio.sleep") as sleep_mock: + with mock.patch("asyncio.wait_for") as sleep_mock: sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] try: instance._flush_timer except TabError: instance._flush_timer = concurrent.futures.Future() assert sleep_mock.call_count == loop_num + 1 - sleep_mock.assert_called_with(expected_sleep) + sleep_kwargs = sleep_mock.call_args[1] + assert sleep_kwargs["timeout"] == expected_sleep assert flush_mock.call_count == 0 def test__flush_timer_close(self): """Timer should continue terminate after close""" - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: + with mock.patch.object(self._get_target_class(), "_schedule_flush"): with self._make_one() as instance: with mock.patch("asyncio.sleep"): time.sleep(0.5) @@ -717,7 +719,7 @@ def test__on_exit(self, recwarn): assert len(w) == 1 assert "unflushed mutations" in str(w[0].message).lower() assert str(num_left) in str(w[0].message) - instance.closed = True + instance._closed.set() instance._on_exit() assert len(recwarn) == 0 instance._staged_entries = [] @@ -835,12 +837,15 @@ def test_customizable_retryable_errors(self, input_retryables, expected_retryabl Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. 
""" - with mock.patch( - "google.api_core.retry.if_exception_type" + retryn_fn = ( + "retry_target_async" + if "Async" in self._get_target_class().__name__ + else "retry_target" + ) + with mock.patch.object( + google.api_core.retry, "if_exception_type" ) as predicate_builder_mock: - with mock.patch( - "google.api_core.retry.retry_target_async" - ) as retry_fn_mock: + with mock.patch.object(google.api_core.retry, retryn_fn) as retry_fn_mock: table = None with mock.patch("asyncio.create_task"): table = Table(mock.Mock(), "instance", "table") From d097aa15c8846665fb681093b491b8a4cc352ea0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Apr 2024 16:55:02 -0700 Subject: [PATCH 035/360] got mutation batcher tests passing --- .../bigtable/data/_async/mutations_batcher.py | 50 +++---- google/cloud/bigtable/data/_sync/_autogen.py | 11 +- .../bigtable/data/_sync/mutations_batcher.py | 47 ++++--- .../cloud/bigtable/data/_sync/sync_gen.yaml | 2 +- .../data/_async/test_mutations_batcher.py | 131 ++++++++++-------- tests/unit/data/_sync/test_autogen.py | 130 ++++++++++------- 6 files changed, 203 insertions(+), 168 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 7e003688d..6d1fd8438 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -234,7 +234,7 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - self._flush_timer = self._start_flush_timer(flush_interval) + self._flush_timer = self._create_bg_task(self._timer_routine, flush_interval) self._flush_jobs: set[asyncio.Future[None]] = set() # MutationExceptionGroup reports number of successful entries along with failures self._entries_processed_since_last_raise: int = 0 @@ -248,36 +248,21 @@ def __init__( # clean up on program exit atexit.register(self._on_exit) - def _start_flush_timer(self, interval: float | None) -> asyncio.Future[None]: + async def _timer_routine(self, interval: float | None) -> None: """ - Set up a background task to flush the batcher every interval seconds - - If interval is None, an empty future is returned - - Args: - - flush_interval: Automatically flush every flush_interval seconds. - If None, no time-based flushing is performed. 
- Returns: - - asyncio.Future that represents the background task - """ - if interval is None or self._closed.is_set(): - empty_future: asyncio.Future[None] = asyncio.Future() - empty_future.set_result(None) - return empty_future - - async def timer_routine(self, interval: float): - """ - Triggers new flush tasks every `interval` seconds - """ - while not self._closed.is_set(): - # wait until interval has passed, or until closed - try: - await asyncio.wait_for(self._closed.wait(), timeout=interval) - except asyncio.TimeoutError: - pass - if not self._closed.is_set() and self._staged_entries: - self._schedule_flush() - return self._create_bg_task(timer_routine, self, interval) + Triggers new flush tasks every `interval` seconds + Ends when the batcher is closed + """ + if not interval or interval <= 0: + return None + while not self._closed.is_set(): + # wait until interval has passed, or until closed + try: + await asyncio.wait_for(self._closed.wait(), timeout=interval) + except asyncio.TimeoutError: + pass + if not self._closed.is_set() and self._staged_entries: + self._schedule_flush() async def append(self, mutation_entry: RowMutationEntry): """ @@ -315,8 +300,9 @@ def _schedule_flush(self) -> asyncio.Future[None] | None: entries, self._staged_entries = self._staged_entries, [] self._staged_count, self._staged_bytes = 0, 0 new_task = self._create_bg_task(self._flush_internal, entries) - new_task.add_done_callback(self._flush_jobs.remove) - self._flush_jobs.add(new_task) + if not new_task.done(): + self._flush_jobs.add(new_task) + new_task.add_done_callback(self._flush_jobs.remove) return new_task return None diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 3e3d857c4..6265d315f 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -561,7 +561,7 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - self._flush_timer = self._start_flush_timer(flush_interval) + self._flush_timer = self._create_bg_task(self._timer_routine, flush_interval) self._flush_jobs: set[concurrent.futures.Future[None]] = set() self._entries_processed_since_last_raise: int = 0 self._exceptions_since_last_raise: int = 0 @@ -572,9 +572,7 @@ def __init__( ) atexit.register(self._on_exit) - def _start_flush_timer( - self, interval: float | None - ) -> concurrent.futures.Future[None]: + def _timer_routine(self, interval: float | None) -> None: raise NotImplementedError("Function not implemented in sync class") def append(self, mutation_entry: RowMutationEntry): @@ -611,8 +609,9 @@ def _schedule_flush(self) -> concurrent.futures.Future[None] | None: (entries, self._staged_entries) = (self._staged_entries, []) (self._staged_count, self._staged_bytes) = (0, 0) new_task = self._create_bg_task(self._flush_internal, entries) - new_task.add_done_callback(self._flush_jobs.remove) - self._flush_jobs.add(new_task) + if not new_task.done(): + self._flush_jobs.add(new_task) + new_task.add_done_callback(self._flush_jobs.remove) return new_task return None diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index f64b85dc6..7bd93873a 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -14,6 +14,8 @@ # from __future__ import annotations +from typing import TYPE_CHECKING + import concurrent.futures import atexit @@ -23,15 +25,24 @@ # import required 
so MutationsBatcher_SyncGen can create _MutateRowsOperation import google.cloud.bigtable.data._sync._mutate_rows # noqa: F401 +if TYPE_CHECKING: + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + class _FlowControl(_FlowControl_SyncGen): pass class MutationsBatcher(MutationsBatcher_SyncGen): - def __init__(self, *args, **kwargs): - self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) - super().__init__(*args, **kwargs) + @property + def _executor(self): + """ + Return a ThreadPoolExecutor for background tasks + """ + if not hasattr(self, "_threadpool"): + self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=8) + return self._threadpool def close(self): """ @@ -68,22 +79,16 @@ def _wait_for_batch_results( exceptions.append(e) return exceptions - def _start_flush_timer( - self, interval: float | None - ) -> concurrent.futures.Future[None]: - if interval is None or self._closed.is_set(): - empty_future: concurrent.futures.Future[None] = concurrent.futures.Future() - empty_future.set_result(None) - return empty_future - - def timer_routine(self, interval: float): - """ - Triggers new flush tasks every `interval` seconds - """ - while not self._closed.is_set(): - # wait until interval has passed, or until closed - self._closed.wait(interval) - if not self._closed.is_set() and self._staged_entries: - self._schedule_flush() - return self._create_bg_task(timer_routine, self, interval) + def _timer_routine(self, interval: float | None) -> None: + """ + Triggers new flush tasks every `interval` seconds + Ends when the batcher is closed + """ + if not interval or interval <= 0: + return None + while not self._closed.is_set(): + # wait until interval has passed, or until closed + self._closed.wait(timeout=interval) + if not self._closed.is_set() and self._staged_entries: + self._schedule_flush() diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index 203e921ff..a34ae832c 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -41,7 +41,7 @@ classes: # Specify transformations for individual classes - path: google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync autogen_sync_name: MutationsBatcher_SyncGen concrete_path: google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher - error_methods: ["_create_bg_task", "close", "_wait_for_batch_results", "_start_flush_timer"] + error_methods: ["_create_bg_task", "close", "_wait_for_batch_results", "_timer_routine"] - path: google.cloud.bigtable.data._async.mutations_batcher._FlowControlAsync autogen_sync_name: _FlowControl_SyncGen concrete_path: google.cloud.bigtable.data._sync.mutations_batcher._FlowControl diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 4a44c980d..adb92b63f 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -14,6 +14,7 @@ import pytest import asyncio +import time import google.api_core.exceptions as core_exceptions import google.api_core.retry from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete @@ -309,11 +310,10 @@ def _get_target_class(self): return MutationsBatcherAsync - def _get_mutate_rows_class_path(self): - # location the MutateRowsOperation class is imported for mocking - # will be different in sync vs async versions - - return 
"google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync" + @staticmethod + def is_async(): + # helepr function for changing tests between sync and async versions + return True def _make_one(self, table=None, **kwargs): from google.api_core.exceptions import DeadlineExceeded @@ -339,7 +339,7 @@ def _make_mutation(count=1, size=1): @pytest.mark.asyncio async def test_ctor_defaults(self): - with mock.patch.object(self._get_target_class(), "_start_flush_timer", return_value=asyncio.Future()) as flush_timer_mock: + with mock.patch.object(self._get_target_class(), "_timer_routine", return_value=asyncio.Future()) as flush_timer_mock: table = mock.Mock() table.default_mutate_rows_operation_timeout = 10 table.default_mutate_rows_attempt_timeout = 8 @@ -376,7 +376,7 @@ async def test_ctor_defaults(self): @pytest.mark.asyncio async def test_ctor_explicit(self): """Test with explicit parameters""" - with mock.patch.object(self._get_target_class(), "_start_flush_timer", return_value=asyncio.Future()) as flush_timer_mock: + with mock.patch.object(self._get_target_class(), "_timer_routine", return_value=asyncio.Future()) as flush_timer_mock: table = mock.Mock() flush_interval = 20 flush_limit_count = 17 @@ -424,7 +424,7 @@ async def test_ctor_explicit(self): @pytest.mark.asyncio async def test_ctor_no_flush_limits(self): """Test with None for flush limits""" - with mock.patch.object(self._get_target_class(), "_start_flush_timer", return_value=asyncio.Future()) as flush_timer_mock: + with mock.patch.object(self._get_target_class(), "_timer_routine", return_value=asyncio.Future()) as flush_timer_mock: table = mock.Mock() table.default_mutate_rows_operation_timeout = 10 table.default_mutate_rows_attempt_timeout = 8 @@ -492,65 +492,65 @@ def test_default_argument_consistency(self): ) @pytest.mark.asyncio - async def test__start_flush_timer_w_None(self): - """Empty timer should return immediately""" + @pytest.mark.parametrize("input_val", [None, 0, -1]) + async def test__start_flush_timer_w_empty_input(self, input_val): + """Empty/invalid timer should return immediately""" with mock.patch.object(self._get_target_class(), "_schedule_flush") as flush_mock: + # mock different method depending on sync vs async async with self._make_one() as instance: - with mock.patch("asyncio.sleep") as sleep_mock: - await instance._start_flush_timer(None) + if self.is_async(): + sleep_obj, sleep_method = asyncio, "wait_for" + else: + sleep_obj, sleep_method = instance._closed, "wait" + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + result = await instance._timer_routine(input_val) assert sleep_mock.call_count == 0 assert flush_mock.call_count == 0 + assert result is None @pytest.mark.asyncio + @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test__start_flush_timer_call_when_closed(self,): """closed batcher's timer should return immediately""" with mock.patch.object(self._get_target_class(), "_schedule_flush") as flush_mock: async with self._make_one() as instance: await instance.close() flush_mock.reset_mock() - with mock.patch("asyncio.sleep") as sleep_mock: - await instance._start_flush_timer(1) + # mock different method depending on sync vs async + if self.is_async(): + sleep_obj, sleep_method = asyncio, "wait_for" + else: + sleep_obj, sleep_method = instance._closed, "wait" + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + await instance._timer_routine(10) assert sleep_mock.call_count == 0 assert flush_mock.call_count == 0 @pytest.mark.asyncio - async 
def test__flush_timer(self): + @pytest.mark.parametrize("num_staged", [0, 1, 10]) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + async def test__flush_timer(self, num_staged): """Timer should continue to call _schedule_flush in a loop""" with mock.patch.object(self._get_target_class(), "_schedule_flush") as flush_mock: expected_sleep = 12 async with self._make_one(flush_interval=expected_sleep) as instance: - instance._staged_entries = [mock.Mock()] loop_num = 3 - with mock.patch("asyncio.wait_for") as sleep_mock: - sleep_mock.side_effect = [None] * loop_num + [ZeroDivisionError("expected")] - try: - await instance._flush_timer - except ZeroDivisionError: - # replace with np-op so there are no issues on close - instance._flush_timer = asyncio.Future() - assert sleep_mock.call_count == loop_num + 1 - sleep_kwargs = sleep_mock.call_args[1] - assert sleep_kwargs["timeout"] == expected_sleep - assert flush_mock.call_count == loop_num - - @pytest.mark.asyncio - async def test__flush_timer_no_mutations(self): - """Timer should not flush if no new mutations have been staged""" - with mock.patch.object(self._get_target_class(), "_schedule_flush") as flush_mock: - expected_sleep = 12 - async with self._make_one(flush_interval=expected_sleep) as instance: - loop_num = 3 - with mock.patch("asyncio.wait_for") as sleep_mock: + instance._staged_entries = [mock.Mock()] * num_staged + # mock different method depending on sync vs async + if self.is_async(): + sleep_obj, sleep_method = asyncio, "wait_for" + else: + sleep_obj, sleep_method = instance._closed, "wait" + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] - try: - await instance._flush_timer - except TabError: - # replace with np-op so there are no issues on close - instance._flush_timer = asyncio.Future() + with pytest.raises(TabError): + await self._get_target_class()._timer_routine(instance, expected_sleep) + # replace with np-op so there are no issues on close + instance._flush_timer = asyncio.Future() assert sleep_mock.call_count == loop_num + 1 sleep_kwargs = sleep_mock.call_args[1] assert sleep_kwargs["timeout"] == expected_sleep - assert flush_mock.call_count == 0 + assert flush_mock.call_count == (0 if num_staged == 0 else loop_num) @pytest.mark.asyncio async def test__flush_timer_close(self): @@ -570,9 +570,9 @@ async def test__flush_timer_close(self): @pytest.mark.asyncio async def test_append_closed(self): """Should raise exception""" + instance = self._make_one() + await instance.close() with pytest.raises(RuntimeError): - instance = self._make_one() - await instance.close() await instance.append(mock.Mock()) @pytest.mark.asyncio @@ -633,7 +633,7 @@ async def mock_call(*args, **kwargs): for _ in range(num_entries): await instance.append(self._make_mutation(size=1)) # let any flush jobs finish - await asyncio.gather(*instance._flush_jobs) + await instance._wait_for_batch_results(*instance._flush_jobs) # should have only flushed once, with large mutation and first mutation in loop assert op_mock.call_count == 1 sent_batch = op_mock.call_args[0][0] @@ -654,6 +654,7 @@ async def mock_call(*args, **kwargs): ], ) @pytest.mark.asyncio + @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_append( self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush ): @@ -732,7 +733,7 @@ async def mock_call(*args, **kwargs): ) await asyncio.sleep(0.01) # allow flushes to complete - await asyncio.gather(*instance._flush_jobs) + 
await instance._wait_for_batch_results(*instance._flush_jobs) duration = time.monotonic() - start_time assert len(instance._oldest_exceptions) == 0 assert len(instance._newest_exceptions) == 0 @@ -750,10 +751,14 @@ async def test_schedule_flush_no_mutations(self): assert flush_mock.call_count == 0 @pytest.mark.asyncio + @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_schedule_flush_with_mutations(self): """if new mutations exist, should add a new flush task to _flush_jobs""" async with self._make_one() as instance: with mock.patch.object(instance, "_flush_internal") as flush_mock: + if not self.is_async(): + # simulate operation + flush_mock.side_effect = lambda x: time.sleep(0.1) for i in range(1, 4): mutation = mock.Mock() instance._staged_entries = [mutation] @@ -764,7 +769,8 @@ async def test_schedule_flush_with_mutations(self): assert instance._staged_entries == [] assert instance._staged_count == 0 assert instance._staged_bytes == 0 - assert flush_mock.call_count == i + assert flush_mock.call_count == 1 + flush_mock.reset_mock() @pytest.mark.asyncio async def test__flush_internal(self): @@ -801,13 +807,19 @@ async def test_flush_clears_job_list(self): and removed when it completes """ async with self._make_one() as instance: - with mock.patch.object(instance, "_flush_internal", AsyncMock()): + with mock.patch.object(instance, "_flush_internal", AsyncMock()) as flush_mock: + if not self.is_async(): + # simulate operation + flush_mock.side_effect = lambda x: time.sleep(0.1) mutations = [self._make_mutation(count=1, size=1)] instance._staged_entries = mutations assert instance._flush_jobs == set() new_job = instance._schedule_flush() assert instance._flush_jobs == {new_job} - await new_job + if self.is_async(): + await new_job + else: + new_job.result() assert instance._flush_jobs == set() @pytest.mark.parametrize( @@ -901,8 +913,11 @@ async def test_timer_flush_end_to_end(self): @pytest.mark.asyncio async def test__execute_mutate_rows(self): - mutate_path = self._get_mutate_rows_class_path() - with mock.patch(mutate_path) as mutate_rows: + if self.is_async(): + mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" + else: + mutate_path = "_sync._mutate_rows._MutateRowsOperation" + with mock.patch(f"google.cloud.bigtable.data.{mutate_path}") as mutate_rows: mutate_rows.return_value = AsyncMock() start_operation = mutate_rows().start table = mock.Mock() @@ -930,8 +945,11 @@ async def test__execute_mutate_rows_returns_errors(self): MutationsExceptionGroup, FailedMutationEntryError, ) - cls_path = self._get_mutate_rows_class_path() - with mock.patch(f"{cls_path}.start") as mutate_rows: + if self.is_async(): + mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" + else: + mutate_path = "_sync._mutate_rows._MutateRowsOperation" + with mock.patch(f"google.cloud.bigtable.data.{mutate_path}.start") as mutate_rows: err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) @@ -1053,8 +1071,11 @@ async def test_timeout_args_passed(self): batch_operation_timeout and batch_attempt_timeout should be used in api calls """ - mutate_path = self._get_mutate_rows_class_path() - with mock.patch(mutate_path, return_value=AsyncMock()) as mutate_rows: + if self.is_async(): + mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" + else: + mutate_path = 
"_sync._mutate_rows._MutateRowsOperation" + with mock.patch(f"google.cloud.bigtable.data.{mutate_path}", return_value=AsyncMock()) as mutate_rows: expected_operation_timeout = 17 expected_attempt_timeout = 13 async with self._make_one( diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py index 5d82121f0..14b4d0f5e 100644 --- a/tests/unit/data/_sync/test_autogen.py +++ b/tests/unit/data/_sync/test_autogen.py @@ -17,6 +17,7 @@ from abc import ABC from unittest import mock +import asyncio import concurrent.futures import mock import pytest @@ -36,8 +37,9 @@ def _get_target_class(self): return MutationsBatcher - def _get_mutate_rows_class_path(self): - return "google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation" + @staticmethod + def is_async(): + return False def _make_one(self, table=None, **kwargs): from google.api_core.exceptions import DeadlineExceeded @@ -63,7 +65,7 @@ def _make_mutation(count=1, size=1): def test_ctor_defaults(self): with mock.patch.object( self._get_target_class(), - "_start_flush_timer", + "_timer_routine", return_value=concurrent.futures.Future(), ) as flush_timer_mock: table = mock.Mock() @@ -105,7 +107,7 @@ def test_ctor_explicit(self): """Test with explicit parameters""" with mock.patch.object( self._get_target_class(), - "_start_flush_timer", + "_timer_routine", return_value=concurrent.futures.Future(), ) as flush_timer_mock: table = mock.Mock() @@ -158,7 +160,7 @@ def test_ctor_no_flush_limits(self): """Test with None for flush limits""" with mock.patch.object( self._get_target_class(), - "_start_flush_timer", + "_timer_routine", return_value=concurrent.futures.Future(), ) as flush_timer_mock: table = mock.Mock() @@ -223,17 +225,24 @@ def test_default_argument_consistency(self): == batcher_init_signature[arg_name].default ) - def test__start_flush_timer_w_None(self): - """Empty timer should return immediately""" + @pytest.mark.parametrize("input_val", [None, 0, -1]) + def test__start_flush_timer_w_empty_input(self, input_val): + """Empty/invalid timer should return immediately""" with mock.patch.object( self._get_target_class(), "_schedule_flush" ) as flush_mock: with self._make_one() as instance: - with mock.patch("asyncio.sleep") as sleep_mock: - instance._start_flush_timer(None) + if self.is_async(): + (sleep_obj, sleep_method) = (asyncio, "wait_for") + else: + (sleep_obj, sleep_method) = (instance._closed, "wait") + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + result = instance._timer_routine(input_val) assert sleep_mock.call_count == 0 assert flush_mock.call_count == 0 + assert result is None + @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test__start_flush_timer_call_when_closed(self): """closed batcher's timer should return immediately""" with mock.patch.object( @@ -242,51 +251,41 @@ def test__start_flush_timer_call_when_closed(self): with self._make_one() as instance: instance.close() flush_mock.reset_mock() - with mock.patch("asyncio.sleep") as sleep_mock: - instance._start_flush_timer(1) + if self.is_async(): + (sleep_obj, sleep_method) = (asyncio, "wait_for") + else: + (sleep_obj, sleep_method) = (instance._closed, "wait") + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + instance._timer_routine(10) assert sleep_mock.call_count == 0 assert flush_mock.call_count == 0 - def test__flush_timer(self): + @pytest.mark.parametrize("num_staged", [0, 1, 10]) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test__flush_timer(self, num_staged): 
"""Timer should continue to call _schedule_flush in a loop""" - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: - expected_sleep = 12 - with self._make_one(flush_interval=expected_sleep) as instance: - instance._staged_entries = [mock.Mock()] - loop_num = 3 - with mock.patch("asyncio.wait_for") as sleep_mock: - sleep_mock.side_effect = [None] * loop_num + [ - ZeroDivisionError("expected") - ] - try: - instance._flush_timer - except ZeroDivisionError: - instance._flush_timer = concurrent.futures.Future() - assert sleep_mock.call_count == loop_num + 1 - sleep_kwargs = sleep_mock.call_args[1] - assert sleep_kwargs["timeout"] == expected_sleep - assert flush_mock.call_count == loop_num - - def test__flush_timer_no_mutations(self): - """Timer should not flush if no new mutations have been staged""" with mock.patch.object( self._get_target_class(), "_schedule_flush" ) as flush_mock: expected_sleep = 12 with self._make_one(flush_interval=expected_sleep) as instance: loop_num = 3 - with mock.patch("asyncio.wait_for") as sleep_mock: + instance._staged_entries = [mock.Mock()] * num_staged + if self.is_async(): + (sleep_obj, sleep_method) = (asyncio, "wait_for") + else: + (sleep_obj, sleep_method) = (instance._closed, "wait") + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] - try: - instance._flush_timer - except TabError: - instance._flush_timer = concurrent.futures.Future() + with pytest.raises(TabError): + self._get_target_class()._timer_routine( + instance, expected_sleep + ) + instance._flush_timer = concurrent.futures.Future() assert sleep_mock.call_count == loop_num + 1 sleep_kwargs = sleep_mock.call_args[1] assert sleep_kwargs["timeout"] == expected_sleep - assert flush_mock.call_count == 0 + assert flush_mock.call_count == (0 if num_staged == 0 else loop_num) def test__flush_timer_close(self): """Timer should continue terminate after close""" @@ -301,9 +300,9 @@ def test__flush_timer_close(self): def test_append_closed(self): """Should raise exception""" + instance = self._make_one() + instance.close() with pytest.raises(RuntimeError): - instance = self._make_one() - instance.close() instance.append(mock.Mock()) def test_append_wrong_mutation(self): @@ -358,7 +357,7 @@ def mock_call(*args, **kwargs): num_entries = 10 for _ in range(num_entries): instance.append(self._make_mutation(size=1)) - print(*instance._flush_jobs) + instance._wait_for_batch_results(*instance._flush_jobs) assert op_mock.call_count == 1 sent_batch = op_mock.call_args[0][0] assert len(sent_batch) == 2 @@ -376,6 +375,7 @@ def mock_call(*args, **kwargs): (1, 1, 0, 0, False), ], ) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_append( self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush ): @@ -447,7 +447,7 @@ def mock_call(*args, **kwargs): [self._make_mutation(count=1)] ) time.sleep(0.01) - print(*instance._flush_jobs) + instance._wait_for_batch_results(*instance._flush_jobs) duration = time.monotonic() - start_time assert len(instance._oldest_exceptions) == 0 assert len(instance._newest_exceptions) == 0 @@ -462,10 +462,13 @@ def test_schedule_flush_no_mutations(self): assert instance._schedule_flush() is None assert flush_mock.call_count == 0 + @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_schedule_flush_with_mutations(self): """if new mutations exist, should add a new flush task to _flush_jobs""" with self._make_one() as instance: with 
mock.patch.object(instance, "_flush_internal") as flush_mock: + if not self.is_async(): + flush_mock.side_effect = lambda x: time.sleep(0.1) for i in range(1, 4): mutation = mock.Mock() instance._staged_entries = [mutation] @@ -475,7 +478,8 @@ def test_schedule_flush_with_mutations(self): assert instance._staged_entries == [] assert instance._staged_count == 0 assert instance._staged_bytes == 0 - assert flush_mock.call_count == i + assert flush_mock.call_count == 1 + flush_mock.reset_mock() def test__flush_internal(self): """ @@ -510,13 +514,20 @@ def test_flush_clears_job_list(self): and removed when it completes """ with self._make_one() as instance: - with mock.patch.object(instance, "_flush_internal", mock.Mock()): + with mock.patch.object( + instance, "_flush_internal", mock.Mock() + ) as flush_mock: + if not self.is_async(): + flush_mock.side_effect = lambda x: time.sleep(0.1) mutations = [self._make_mutation(count=1, size=1)] instance._staged_entries = mutations assert instance._flush_jobs == set() new_job = instance._schedule_flush() assert instance._flush_jobs == {new_job} - new_job + if self.is_async(): + new_job + else: + new_job.result() assert instance._flush_jobs == set() @pytest.mark.parametrize( @@ -601,8 +612,11 @@ def test_timer_flush_end_to_end(self): assert instance._entries_processed_since_last_raise == num_nutations def test__execute_mutate_rows(self): - mutate_path = self._get_mutate_rows_class_path() - with mock.patch(mutate_path) as mutate_rows: + if self.is_async(): + mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" + else: + mutate_path = "_sync._mutate_rows._MutateRowsOperation" + with mock.patch(f"google.cloud.bigtable.data.{mutate_path}") as mutate_rows: mutate_rows.return_value = mock.Mock() start_operation = mutate_rows().start table = mock.Mock() @@ -630,8 +644,13 @@ def test__execute_mutate_rows_returns_errors(self): FailedMutationEntryError, ) - cls_path = self._get_mutate_rows_class_path() - with mock.patch(f"{cls_path}.start") as mutate_rows: + if self.is_async(): + mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" + else: + mutate_path = "_sync._mutate_rows._MutateRowsOperation" + with mock.patch( + f"google.cloud.bigtable.data.{mutate_path}.start" + ) as mutate_rows: err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) @@ -738,8 +757,13 @@ def test_timeout_args_passed(self): batch_operation_timeout and batch_attempt_timeout should be used in api calls """ - mutate_path = self._get_mutate_rows_class_path() - with mock.patch(mutate_path, return_value=mock.Mock()) as mutate_rows: + if self.is_async(): + mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" + else: + mutate_path = "_sync._mutate_rows._MutateRowsOperation" + with mock.patch( + f"google.cloud.bigtable.data.{mutate_path}", return_value=mock.Mock() + ) as mutate_rows: expected_operation_timeout = 17 expected_attempt_timeout = 13 with self._make_one( From 8df5920947f1b7a81897f8f0d08d138ee3de8e3e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Sat, 13 Apr 2024 09:31:22 -0700 Subject: [PATCH 036/360] did some refactoring of client tests --- google/cloud/bigtable/data/_sync/client.py | 1 + tests/unit/data/_async/test_client.py | 221 ++++++++++++--------- 2 files changed, 127 insertions(+), 95 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/client.py 
b/google/cloud/bigtable/data/_sync/client.py index 0bb6b9586..54d143cef 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -33,6 +33,7 @@ def __init__( client_options: dict[str, Any] | "google.api_core.client_options.ClientOptions" | None = None, + **kwargs ): # remove pool size option in sync client super().__init__( diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index a0019947d..065b00b07 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -45,38 +45,38 @@ ) -def _make_client(*args, use_emulator=True, **kwargs): - import os - from google.cloud.bigtable.data._async.client import BigtableDataClientAsync - - env_mask = {} - # by default, use emulator mode to avoid auth issues in CI - # emulator mode must be disabled by tests that check channel pooling/refresh background tasks - if use_emulator: - env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" - else: - # set some default values - kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) - kwargs["project"] = kwargs.get("project", "project-id") - with mock.patch.dict(os.environ, env_mask): - return BigtableDataClientAsync(*args, **kwargs) - - class TestBigtableDataClientAsync: - def _get_target_class(self): + + @staticmethod + def _get_target_class(): from google.cloud.bigtable.data._async.client import BigtableDataClientAsync return BigtableDataClientAsync - def _make_one(self, *args, **kwargs): - return _make_client(*args, **kwargs) + @classmethod + def _make_client(cls, *args, use_emulator=True, **kwargs): + import os + + env_mask = {} + # by default, use emulator mode to avoid auth issues in CI + # emulator mode must be disabled by tests that check channel pooling/refresh background tasks + if use_emulator: + env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" + import warnings + warnings.filterwarnings('ignore', category=RuntimeWarning) + else: + # set some default values + kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) + kwargs["project"] = kwargs.get("project", "project-id") + with mock.patch.dict(os.environ, env_mask): + return cls._get_target_class()(*args, **kwargs) @pytest.mark.asyncio async def test_ctor(self): expected_project = "project-id" expected_pool_size = 11 expected_credentials = AnonymousCredentials() - client = self._make_one( + client = self._make_client( project="project-id", pool_size=expected_pool_size, credentials=expected_credentials, @@ -111,7 +111,7 @@ async def test_ctor_super_inits(self): ) as client_project_init: client_project_init.return_value = None try: - self._make_one( + self._make_client( project=project, pool_size=pool_size, credentials=credentials, @@ -143,7 +143,7 @@ async def test_ctor_dict_options(self): client_options = {"api_endpoint": "foo.bar:1234"} with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: try: - self._make_one(client_options=client_options) + self._make_client(client_options=client_options) except TypeError: pass bigtable_client_init.assert_called_once() @@ -154,7 +154,7 @@ async def test_ctor_dict_options(self): with mock.patch.object( self._get_target_class(), "_start_background_channel_refresh" ) as start_background_refresh: - client = self._make_one(client_options=client_options, use_emulator=False) + client = self._make_client(client_options=client_options, use_emulator=False) start_background_refresh.assert_called_once() await client.close() @@ -164,7 +164,7 @@ 
async def test_veneer_grpc_headers(self): # detect as a veneer client patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") with patch as gapic_mock: - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") wrapped_call_list = gapic_mock.call_args_list assert len(wrapped_call_list) > 0 # each wrapped call should have veneer headers @@ -186,11 +186,11 @@ async def test_channel_pool_creation(self): "google.api_core.grpc_helpers_async.create_channel" ) as create_channel: create_channel.return_value = AsyncMock() - client = self._make_one(project="project-id", pool_size=pool_size) + client = self._make_client(project="project-id", pool_size=pool_size) assert create_channel.call_count == pool_size await client.close() # channels should be unique - client = self._make_one(project="project-id", pool_size=pool_size) + client = self._make_client(project="project-id", pool_size=pool_size) pool_list = list(client.transport._grpc_channel._pool) pool_set = set(client.transport._grpc_channel._pool) assert len(pool_list) == len(pool_set) @@ -205,7 +205,7 @@ async def test_channel_pool_rotation(self): pool_size = 7 with mock.patch.object(PooledChannel, "next_channel") as next_channel: - client = self._make_one(project="project-id", pool_size=pool_size) + client = self._make_client(project="project-id", pool_size=pool_size) assert len(client.transport._grpc_channel._pool) == pool_size next_channel.reset_mock() with mock.patch.object( @@ -228,7 +228,7 @@ async def test_channel_pool_rotation(self): async def test_channel_pool_replace(self): with mock.patch.object(asyncio, "sleep"): pool_size = 7 - client = self._make_one(project="project-id", pool_size=pool_size) + client = self._make_client(project="project-id", pool_size=pool_size) for replace_idx in range(pool_size): start_pool = [ channel for channel in client.transport._grpc_channel._pool @@ -254,14 +254,14 @@ async def test_channel_pool_replace(self): @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test__start_background_channel_refresh_sync(self): # should raise RuntimeError if called in a sync context - client = self._make_one(project="project-id", use_emulator=False) + client = self._make_client(project="project-id", use_emulator=False) with pytest.raises(RuntimeError): client._start_background_channel_refresh() @pytest.mark.asyncio async def test__start_background_channel_refresh_tasks_exist(self): # if tasks exist, should do nothing - client = self._make_one(project="project-id", use_emulator=False) + client = self._make_client(project="project-id", use_emulator=False) assert len(client._channel_refresh_tasks) > 0 with mock.patch.object(asyncio, "create_task") as create_task: client._start_background_channel_refresh() @@ -272,7 +272,7 @@ async def test__start_background_channel_refresh_tasks_exist(self): @pytest.mark.parametrize("pool_size", [1, 3, 7]) async def test__start_background_channel_refresh(self, pool_size): # should create background tasks for each channel - client = self._make_one( + client = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False ) ping_and_warm = AsyncMock() @@ -294,7 +294,7 @@ async def test__start_background_channel_refresh(self, pool_size): async def test__start_background_channel_refresh_tasks_names(self): # if tasks exist, should do nothing pool_size = 3 - client = self._make_one( + client = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False ) for i in range(pool_size): @@ -410,7 +410,7 @@ 
async def test__manage_channel_first_sleep( with mock.patch.object(asyncio, "sleep") as sleep: sleep.side_effect = asyncio.CancelledError try: - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") client._channel_init_time = -wait_time await client._manage_channel(0, refresh_interval, refresh_interval) except asyncio.CancelledError: @@ -492,7 +492,7 @@ async def test__manage_channel_sleeps( asyncio.CancelledError ] try: - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") if refresh_interval is not None: await client._manage_channel( channel_idx, refresh_interval, refresh_interval @@ -517,7 +517,7 @@ async def test__manage_channel_random(self): uniform.return_value = 0 try: uniform.side_effect = asyncio.CancelledError - client = self._make_one(project="project-id", pool_size=1) + client = self._make_client(project="project-id", pool_size=1) except asyncio.CancelledError: uniform.side_effect = None uniform.reset_mock() @@ -561,7 +561,7 @@ async def test__manage_channel_refresh(self, num_cycles): grpc_helpers_async, "create_channel" ) as create_channel: create_channel.return_value = new_channel - client = self._make_one(project="project-id", use_emulator=False) + client = self._make_client(project="project-id", use_emulator=False) create_channel.reset_mock() try: await client._manage_channel( @@ -714,7 +714,7 @@ async def test__register_instance_state( @pytest.mark.asyncio async def test__remove_instance_registration(self): - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") table = mock.Mock() await client._register_instance("instance-1", table) await client._register_instance("instance-2", table) @@ -752,7 +752,7 @@ async def test__multiple_table_registration(self): """ from google.cloud.bigtable.data._async.client import _WarmedInstanceKey - async with self._make_one(project="project-id") as client: + async with self._make_client(project="project-id") as client: async with client.get_table("instance_1", "table_1") as table_1: instance_1_path = client._gapic_client.instance_path( client.project, "instance_1" @@ -800,7 +800,7 @@ async def test__multiple_instance_registration(self): """ from google.cloud.bigtable.data._async.client import _WarmedInstanceKey - async with self._make_one(project="project-id") as client: + async with self._make_client(project="project-id") as client: async with client.get_table("instance_1", "table_1") as table_1: async with client.get_table("instance_2", "table_2") as table_2: instance_1_path = client._gapic_client.instance_path( @@ -836,7 +836,7 @@ async def test_get_table(self): from google.cloud.bigtable.data._async.client import TableAsync from google.cloud.bigtable.data._async.client import _WarmedInstanceKey - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") assert not client._active_instances expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -872,7 +872,7 @@ async def test_get_table_arg_passthrough(self): """ All arguments passed in get_table should be sent to constructor """ - async with self._make_one(project="project-id") as client: + async with self._make_client(project="project-id") as client: with mock.patch( "google.cloud.bigtable.data._async.client.TableAsync.__init__", ) as mock_constructor: @@ -911,7 +911,7 @@ async def test_get_table_context_manager(self): expected_project_id = "project-id" with mock.patch.object(TableAsync, "close") as 
close_mock: - async with self._make_one(project=expected_project_id) as client: + async with self._make_client(project=expected_project_id) as client: async with client.get_table( expected_instance_id, expected_table_id, @@ -943,11 +943,11 @@ async def test_multiple_pool_sizes(self): # should be able to create multiple clients with different pool sizes without issue pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] for pool_size in pool_sizes: - client = self._make_one( + client = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False ) assert len(client._channel_refresh_tasks) == pool_size - client_duplicate = self._make_one( + client_duplicate = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False ) assert len(client_duplicate._channel_refresh_tasks) == pool_size @@ -962,7 +962,7 @@ async def test_close(self): ) pool_size = 7 - client = self._make_one( + client = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False ) assert len(client._channel_refresh_tasks) == pool_size @@ -984,7 +984,7 @@ async def test_close(self): async def test_close_with_timeout(self): pool_size = 7 expected_timeout = 19 - client = self._make_one(project="project-id", pool_size=pool_size) + client = self._make_client(project="project-id", pool_size=pool_size) tasks = list(client._channel_refresh_tasks) with mock.patch.object(asyncio, "wait_for", AsyncMock()) as wait_for_mock: await client.close(timeout=expected_timeout) @@ -999,7 +999,7 @@ async def test_context_manager(self): # context manager should close the client cleanly close_mock = AsyncMock() true_close = None - async with self._make_one(project="project-id") as client: + async with self._make_client(project="project-id") as client: true_close = client.close() client.close = close_mock for task in client._channel_refresh_tasks: @@ -1016,7 +1016,7 @@ def test_client_ctor_sync(self): # initializing client in a sync context should raise RuntimeError with pytest.warns(RuntimeWarning) as warnings: - client = _make_client(project="project-id", use_emulator=False) + client = self._make_client(project="project-id", use_emulator=False) expected_warning = [w for w in warnings if "client.py" in w.filename] assert len(expected_warning) == 1 assert ( @@ -1028,6 +1028,10 @@ def test_client_ctor_sync(self): class TestTableAsync: + + def _make_client(self, *args, **kwargs): + return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @pytest.mark.asyncio async def test_table_ctor(self): from google.cloud.bigtable.data._async.client import TableAsync @@ -1042,7 +1046,7 @@ async def test_table_ctor(self): expected_read_rows_attempt_timeout = 0.5 expected_mutate_rows_operation_timeout = 2.5 expected_mutate_rows_attempt_timeout = 0.75 - client = _make_client() + client = self._make_client() assert not client._active_instances table = TableAsync( @@ -1101,7 +1105,7 @@ async def test_table_ctor_defaults(self): expected_table_id = "table-id" expected_instance_id = "instance-id" - client = _make_client() + client = self._make_client() assert not client._active_instances table = TableAsync( @@ -1129,7 +1133,7 @@ async def test_table_ctor_invalid_timeout_values(self): """ from google.cloud.bigtable.data._async.client import TableAsync - client = _make_client() + client = self._make_client() timeout_pairs = [ ("default_operation_timeout", "default_attempt_timeout"), @@ -1248,7 +1252,7 @@ async def test_customizable_retryable_errors( down to the gapic layer. 
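The refactor running through this commit replaces the module-level _make_client helper with a classmethod on TestBigtableDataClientAsync, which the remaining test classes delegate to via the small _make_client shims in the surrounding hunks. The point is that a future sync test class only has to override _get_target_class to reuse every test body. A trimmed, illustrative sketch of that shape (class names below are made up, and the credential handling from the real helper is omitted):

import os
from unittest import mock

class TestDataClient:
    @staticmethod
    def _get_target_class():
        from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
        return BigtableDataClientAsync

    @classmethod
    def _make_client(cls, *args, use_emulator=True, **kwargs):
        # emulator mode sidesteps auth and background channel refresh in CI
        env = {"BIGTABLE_EMULATOR_HOST": "localhost"} if use_emulator else {}
        with mock.patch.dict(os.environ, env):
            return cls._get_target_class()(*args, **kwargs)

class TestSomeOtherSurface:
    # the other test classes delegate instead of duplicating the factory
    def _make_client(self, *args, **kwargs):
        return TestDataClient._make_client(*args, **kwargs)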
""" with mock.patch(retry_fn_path) as retry_fn_mock: - async with _make_client() as client: + async with self._make_client() as client: table = client.get_table("instance-id", "table-id") expected_predicate = lambda a: a in expected_retryables # noqa retry_fn_mock.side_effect = RuntimeError("stop early") @@ -1302,7 +1306,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ f"google.cloud.bigtable_v2.BigtableAsyncClient.{gapic_fn}", mock.AsyncMock() ) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") - async with _make_client() as client: + async with self._make_client() as client: table = TableAsync(client, "instance-id", "table-id", profile) try: test_fn = table.__getattribute__(fn_name) @@ -1330,6 +1334,9 @@ class TestReadRows: Tests for table.read_rows and related methods. """ + def _make_client(self, *args, **kwargs): + return TestBigtableDataClientAsync._make_client(*args, **kwargs) + def _make_table(self, *args, **kwargs): from google.cloud.bigtable.data._async.client import TableAsync @@ -1698,7 +1705,7 @@ async def test_read_rows_default_timeout_override(self): @pytest.mark.asyncio async def test_read_row(self): """Test reading a single row""" - async with _make_client() as client: + async with self._make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1726,7 +1733,7 @@ async def test_read_row(self): @pytest.mark.asyncio async def test_read_row_w_filter(self): """Test reading a single row with an added filter""" - async with _make_client() as client: + async with self._make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1759,7 +1766,7 @@ async def test_read_row_w_filter(self): @pytest.mark.asyncio async def test_read_row_no_response(self): """should return None if row does not exist""" - async with _make_client() as client: + async with self._make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1794,7 +1801,7 @@ async def test_read_row_no_response(self): @pytest.mark.asyncio async def test_row_exists(self, return_value, expected_result): """Test checking for row existence""" - async with _make_client() as client: + async with self._make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1829,9 +1836,13 @@ async def test_row_exists(self, return_value, expected_result): class TestReadRowsSharded: + + def _make_client(self, *args, **kwargs): + return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @pytest.mark.asyncio async def test_read_rows_sharded_empty_query(self): - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as exc: await table.read_rows_sharded([]) @@ -1842,7 +1853,7 @@ async def test_read_rows_sharded_multiple_queries(self): """ Test with multiple queries. 
Should return results from both """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( table.client._gapic_client, "read_rows" @@ -1868,7 +1879,7 @@ async def test_read_rows_sharded_multiple_queries_calls(self, n_queries): """ Each query should trigger a separate read_rows call """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: query_list = [ReadRowsQuery() for _ in range(n_queries)] @@ -1883,7 +1894,7 @@ async def test_read_rows_sharded_errors(self): from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.cloud.bigtable.data.exceptions import FailedQueryShardError - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = RuntimeError("mock error") @@ -1914,7 +1925,7 @@ async def mock_call(*args, **kwargs): await asyncio.sleep(0.1) return [mock.Mock()] - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call @@ -1987,6 +1998,10 @@ async def test_read_rows_sharded_batching(self): class TestSampleRowKeys: + + def _make_client(self, *args, **kwargs): + return TestBigtableDataClientAsync._make_client(*args, **kwargs) + async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): from google.cloud.bigtable_v2.types import SampleRowKeysResponse @@ -2003,7 +2018,7 @@ async def test_sample_row_keys(self): (b"test_2", 100), (b"test_3", 200), ] - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( table.client._gapic_client, "sample_row_keys", AsyncMock() @@ -2023,7 +2038,7 @@ async def test_sample_row_keys_bad_timeout(self): """ should raise error if timeout is negative """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.sample_row_keys(operation_timeout=-1) @@ -2036,7 +2051,7 @@ async def test_sample_row_keys_bad_timeout(self): async def test_sample_row_keys_default_timeout(self): """Should fallback to using table default operation_timeout""" expected_timeout = 99 - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table( "i", "t", @@ -2062,7 +2077,7 @@ async def test_sample_row_keys_gapic_params(self): expected_profile = "test1" instance = "instance_name" table_id = "my_table" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table( instance, table_id, app_profile_id=expected_profile ) as table: @@ -2095,7 +2110,7 @@ async def test_sample_row_keys_retryable_errors(self, retryable_exception): from google.api_core.exceptions import DeadlineExceeded from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( 
table.client._gapic_client, "sample_row_keys", AsyncMock() @@ -2124,7 +2139,7 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio """ non-retryable errors should cause a raise """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( table.client._gapic_client, "sample_row_keys", AsyncMock() @@ -2135,6 +2150,10 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio class TestMutateRow: + + def _make_client(self, *args, **kwargs): + return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @pytest.mark.asyncio @pytest.mark.parametrize( "mutation_arg", @@ -2156,7 +2175,7 @@ class TestMutateRow: async def test_mutate_row(self, mutation_arg): """Test mutations with no errors""" expected_attempt_timeout = 19 - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2196,7 +2215,7 @@ async def test_mutate_row_retryable_errors(self, retryable_exception): from google.api_core.exceptions import DeadlineExceeded from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2226,7 +2245,7 @@ async def test_mutate_row_non_idempotent_retryable_errors( """ Non-idempotent mutations should not be retried """ - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2254,7 +2273,7 @@ async def test_mutate_row_non_idempotent_retryable_errors( ) @pytest.mark.asyncio async def test_mutate_row_non_retryable_errors(self, non_retryable_exception): - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2277,7 +2296,7 @@ async def test_mutate_row_non_retryable_errors(self, non_retryable_exception): async def test_mutate_row_metadata(self, include_app_profile): """request should attach metadata headers""" profile = "profile" if include_app_profile else None - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("i", "t", app_profile_id=profile) as table: with mock.patch.object( client._gapic_client, "mutate_row", AsyncMock() @@ -2299,7 +2318,7 @@ async def test_mutate_row_metadata(self, include_app_profile): @pytest.mark.parametrize("mutations", [[], None]) @pytest.mark.asyncio async def test_mutate_row_no_mutations(self, mutations): - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.mutate_row("key", mutations=mutations) @@ -2307,6 +2326,10 @@ async def test_mutate_row_no_mutations(self, mutations): class TestBulkMutateRows: + + def _make_client(self, *args, **kwargs): + return TestBigtableDataClientAsync._make_client(*args, **kwargs) + async def 
_mock_response(self, response_list): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -2355,7 +2378,7 @@ async def generator(): async def test_bulk_mutate_rows(self, mutation_arg): """Test mutations with no errors""" expected_attempt_timeout = 19 - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2379,7 +2402,7 @@ async def test_bulk_mutate_rows(self, mutation_arg): @pytest.mark.asyncio async def test_bulk_mutate_rows_multiple_entries(self): """Test mutations with no errors""" - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2420,7 +2443,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_retryable( MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2466,7 +2489,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable( MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2506,7 +2529,7 @@ async def test_bulk_mutate_idempotent_retryable_request_errors( MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2544,7 +2567,7 @@ async def test_bulk_mutate_rows_non_idempotent_retryable_errors( MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2586,7 +2609,7 @@ async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_excepti MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2622,7 +2645,7 @@ async def test_bulk_mutate_error_index(self): MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2664,7 +2687,7 @@ async def test_bulk_mutate_error_recovery(self): """ from google.api_core.exceptions import DeadlineExceeded - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: table = client.get_table("instance", "table") with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: # fail with a retryable error, then a non-retryable one @@ -2683,13 +2706,17 @@ async def 
test_bulk_mutate_error_recovery(self): class TestCheckAndMutateRow: + + def _make_client(self, *args, **kwargs): + return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @pytest.mark.parametrize("gapic_result", [True, False]) @pytest.mark.asyncio async def test_check_and_mutate(self, gapic_result): from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse app_profile = "app_profile_id" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table( "instance", "table", app_profile_id=app_profile ) as table: @@ -2729,7 +2756,7 @@ async def test_check_and_mutate(self, gapic_result): @pytest.mark.asyncio async def test_check_and_mutate_bad_timeout(self): """Should raise error if operation_timeout < 0""" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.check_and_mutate_row( @@ -2747,7 +2774,7 @@ async def test_check_and_mutate_single_mutations(self): from google.cloud.bigtable.data.mutations import SetCell from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "check_and_mutate_row" @@ -2775,7 +2802,7 @@ async def test_check_and_mutate_predicate_object(self): mock_predicate = mock.Mock() predicate_pb = {"predicate": "dict"} mock_predicate._to_pb.return_value = predicate_pb - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "check_and_mutate_row" @@ -2803,7 +2830,7 @@ async def test_check_and_mutate_mutations_parsing(self): for idx, mutation in enumerate(mutations): mutation._to_pb.return_value = f"fake {idx}" mutations.append(DeleteAllFromRow()) - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "check_and_mutate_row" @@ -2831,6 +2858,10 @@ async def test_check_and_mutate_mutations_parsing(self): class TestReadModifyWriteRow: + + def _make_client(self, *args, **kwargs): + return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @pytest.mark.parametrize( "call_rules,expected_rules", [ @@ -2857,7 +2888,7 @@ async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules """ Test that the gapic call is called with given rules """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" @@ -2871,7 +2902,7 @@ async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules @pytest.mark.parametrize("rules", [[], None]) @pytest.mark.asyncio async def test_read_modify_write_no_rules(self, rules): - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.read_modify_write_row("key", rules=rules) @@ -2883,7 +2914,7 @@ async def test_read_modify_write_call_defaults(self): table_id = "table1" project = "project1" row_key = "row_key1" - async with _make_client(project=project) as client: + async with 
self._make_client(project=project) as client: async with client.get_table(instance, table_id) as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" @@ -2904,7 +2935,7 @@ async def test_read_modify_write_call_overrides(self): row_key = b"row_key1" expected_timeout = 12345 profile_id = "profile1" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table( "instance", "table_id", app_profile_id=profile_id ) as table: @@ -2925,7 +2956,7 @@ async def test_read_modify_write_call_overrides(self): @pytest.mark.asyncio async def test_read_modify_write_string_key(self): row_key = "string_row_key1" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table_id") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" @@ -2945,7 +2976,7 @@ async def test_read_modify_write_row_building(self): from google.cloud.bigtable_v2.types import Row as RowPB mock_response = ReadModifyWriteRowResponse(row=RowPB()) - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table_id") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" From dcc0ab695448cd5c41d21635da83cc185f436ff2 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Sat, 13 Apr 2024 10:05:34 -0700 Subject: [PATCH 037/360] added sharded read rows sync --- google/cloud/bigtable/data/_async/client.py | 28 ++++-- google/cloud/bigtable/data/_sync/_autogen.py | 96 +++++++++++++++++++ google/cloud/bigtable/data/_sync/client.py | 17 +++- .../cloud/bigtable/data/_sync/sync_gen.yaml | 3 +- tests/unit/data/_async/test_client.py | 2 + 5 files changed, 135 insertions(+), 11 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 26e439e43..b3ae852dc 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -792,16 +792,16 @@ async def read_rows_sharded( shard_idx = 0 for batch in batched_queries: batch_operation_timeout = next(timeout_generator) - routine_list = [ - self.read_rows( - query, - operation_timeout=batch_operation_timeout, - attempt_timeout=min(attempt_timeout, batch_operation_timeout), - retryable_errors=retryable_errors, - ) + batch_kwargs_list = [ + { + "query": query, + "operation_timeout": batch_operation_timeout, + "attempt_timeout": min(attempt_timeout, batch_operation_timeout), + "retryable_errors": retryable_errors, + } for query in batch ] - batch_result = await asyncio.gather(*routine_list, return_exceptions=True) + batch_result = await self._shard_batch_helper(batch_kwargs_list) for result in batch_result: if isinstance(result, Exception): error_dict[shard_idx] = result @@ -823,6 +823,18 @@ async def read_rows_sharded( ) return results_list + async def _shard_batch_helper(self, kwargs_list: list[dict]) -> list[list[Row] | BaseException]: + """ + Helper function for executing a batch of read_rows queries in parallel + + Sync client implementation will override this method + """ + routine_list = [ + self.read_rows(**kwargs) for kwargs in kwargs_list + ] + return await asyncio.gather(*routine_list, return_exceptions=True) + + async def row_exists( self, row_key: str | bytes, diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 6265d315f..cc80ceaed 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ 
b/google/cloud/bigtable/data/_sync/_autogen.py @@ -47,7 +47,9 @@ from google.cloud.bigtable.data._async._read_rows import _ResetRow from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data._helpers import RowKeySamples +from google.cloud.bigtable.data._helpers import ShardedQuery from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT from google.cloud.bigtable.data._helpers import _WarmedInstanceKey from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _get_retryable_errors @@ -56,8 +58,10 @@ from google.cloud.bigtable.data._helpers import _retry_exception_factory from google.cloud.bigtable.data._helpers import _validate_timeouts from google.cloud.bigtable.data.exceptions import FailedMutationEntryError +from google.cloud.bigtable.data.exceptions import FailedQueryShardError from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup +from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.cloud.bigtable.data.exceptions import _RowSetComplete from google.cloud.bigtable.data.mutations import Mutation from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -1352,6 +1356,98 @@ def read_row( return None return results[0] + def read_rows_sharded( + self, + sharded_query: ShardedQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> list[Row]: + """ + Runs a sharded query in parallel, then return the results in a single list. + Results will be returned in the order of the input queries. + + This function is intended to be run on the results on a query.shard() call: + + ``` + table_shard_keys = await table.sample_row_keys() + query = ReadRowsQuery(...) + shard_queries = query.shard(table_shard_keys) + results = await table.read_rows_sharded(shard_queries) + ``` + + Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not + yet recommended for production use. + + Args: + - sharded_query: a sharded query to execute + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. 
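A minimal standalone sketch of the batching strategy used in the implementation below: shard queries are sliced into groups of at most `_CONCURRENCY_LIMIT`, and only one group is dispatched at a time (the limit value here is illustrative, not the library's actual constant).

```
_CONCURRENCY_LIMIT = 10  # illustrative value only
sharded_query = [f"shard_{i}" for i in range(25)]  # stand-in for shard queries
batched_queries = [
    sharded_query[i : i + _CONCURRENCY_LIMIT]
    for i in range(0, len(sharded_query), _CONCURRENCY_LIMIT)
]
# produces batches of sizes 10, 10, and 5; each batch runs to completion
# before the next batch is dispatched
```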
+ Raises: + - ShardedReadRowsExceptionGroup: if any of the queries failed + - ValueError: if the query_list is empty + """ + if not sharded_query: + raise ValueError("empty sharded_query") + (operation_timeout, attempt_timeout) = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + timeout_generator = _attempt_timeout_generator( + operation_timeout, operation_timeout + ) + batched_queries = [ + sharded_query[i : i + _CONCURRENCY_LIMIT] + for i in range(0, len(sharded_query), _CONCURRENCY_LIMIT) + ] + results_list = [] + error_dict = {} + shard_idx = 0 + for batch in batched_queries: + batch_operation_timeout = next(timeout_generator) + batch_kwargs_list = [ + { + "query": query, + "operation_timeout": batch_operation_timeout, + "attempt_timeout": min(attempt_timeout, batch_operation_timeout), + "retryable_errors": retryable_errors, + } + for query in batch + ] + batch_result = self._shard_batch_helper(batch_kwargs_list) + for result in batch_result: + if isinstance(result, Exception): + error_dict[shard_idx] = result + elif isinstance(result, BaseException): + raise result + else: + results_list.extend(result) + shard_idx += 1 + if error_dict: + raise ShardedReadRowsExceptionGroup( + [ + FailedQueryShardError(idx, sharded_query[idx], e) + for (idx, e) in error_dict.items() + ], + results_list, + len(sharded_query), + ) + return results_list + + def _shard_batch_helper( + self, kwargs_list: list[dict] + ) -> list[list[Row] | BaseException]: + raise NotImplementedError("Function not implemented in sync class") + def row_exists( self, row_key: str | bytes, diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 54d143cef..c12a5b968 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -14,15 +14,19 @@ # from __future__ import annotations -from typing import Any +from typing import Any, TYPE_CHECKING import grpc import google.auth.credentials +import concurrent.futures from google.cloud.bigtable.data._sync._autogen import BigtableDataClient_SyncGen from google.cloud.bigtable.data._sync._autogen import Table_SyncGen +if TYPE_CHECKING: + from google.cloud.bigtable.data.row import Row + class BigtableDataClient(BigtableDataClient_SyncGen): def __init__( @@ -60,3 +64,14 @@ class Table(Table_SyncGen): def _register_with_client(self): self.client._register_instance(self.instance_id, self) self._register_instance_task = None + + def _shard_batch_helper(self, kwargs_list: list[dict]) -> list[list[Row] | BaseException]: + with concurrent.futures.ThreadPoolExecutor() as executor: + futures_list = [executor.submit(self.read_rows, **kwargs) for kwargs in kwargs_list] + results_list: list[list[Row] | BaseException] = [] + for future in concurrent.futures.as_completed(futures_list): + if future.exception(): + results_list.append(future.exception()) + else: + results_list.append(future.result()) + return results_list diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index a34ae832c..23568c0e4 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -54,7 +54,6 @@ classes: # Specify transformations for individual classes - path: google.cloud.bigtable.data._async.client.TableAsync autogen_sync_name: Table_SyncGen concrete_path: google.cloud.bigtable.data._sync.client.Table - drop_methods: ["read_rows_sharded"] - error_methods: ["_register_with_client"] + error_methods: 
["_register_with_client", "_shard_batch_helper"] save_path: "google/cloud/bigtable/data/_sync/_autogen.py" diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 065b00b07..7ca8fdc21 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1944,6 +1944,7 @@ async def test_read_rows_sharded_batching(self): Large queries should be processed in batches to limit concurrency operation timeout should change between batches """ + import functools from google.cloud.bigtable.data._async.client import TableAsync from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT @@ -1958,6 +1959,7 @@ async def test_read_rows_sharded_batching(self): start_attempt_timeout = 3 table_mock.default_read_rows_operation_timeout = start_operation_timeout table_mock.default_read_rows_attempt_timeout = start_attempt_timeout + table_mock._shard_batch_helper = functools.partial(TableAsync._shard_batch_helper, table_mock) # clock ticks one second on each check with mock.patch("time.monotonic", side_effect=range(0, 100000)): with mock.patch("asyncio.gather", AsyncMock()) as gather_mock: From 10c565f929c4fdaf803af4e4126028cc8e3e623c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Sat, 13 Apr 2024 10:38:19 -0700 Subject: [PATCH 038/360] made tweaks to read rows sharded --- google/cloud/bigtable/data/_sync/__init__.py | 0 google/cloud/bigtable/data/_sync/client.py | 6 +- .../cloud/bigtable/data/_sync/unit_tests.yaml | 70 +++++++++++++++++++ tests/unit/data/_async/test_client.py | 4 +- 4 files changed, 77 insertions(+), 3 deletions(-) create mode 100644 google/cloud/bigtable/data/_sync/__init__.py create mode 100644 google/cloud/bigtable/data/_sync/unit_tests.yaml diff --git a/google/cloud/bigtable/data/_sync/__init__.py b/google/cloud/bigtable/data/_sync/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index c12a5b968..e4a8ab193 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -24,6 +24,10 @@ from google.cloud.bigtable.data._sync._autogen import BigtableDataClient_SyncGen from google.cloud.bigtable.data._sync._autogen import Table_SyncGen +# import required so Table_SyncGen can create _MutateRowsOperation and _ReadRowsOperation +import google.cloud.bigtable.data._sync._read_rows # noqa: F401 +import google.cloud.bigtable.data._sync._mutate_rows # noqa: F401 + if TYPE_CHECKING: from google.cloud.bigtable.data.row import Row @@ -69,7 +73,7 @@ def _shard_batch_helper(self, kwargs_list: list[dict]) -> list[list[Row] | BaseE with concurrent.futures.ThreadPoolExecutor() as executor: futures_list = [executor.submit(self.read_rows, **kwargs) for kwargs in kwargs_list] results_list: list[list[Row] | BaseException] = [] - for future in concurrent.futures.as_completed(futures_list): + for future in futures_list: if future.exception(): results_list.append(future.exception()) else: diff --git a/google/cloud/bigtable/data/_sync/unit_tests.yaml b/google/cloud/bigtable/data/_sync/unit_tests.yaml new file mode 100644 index 000000000..ccb690a74 --- /dev/null +++ b/google/cloud/bigtable/data/_sync/unit_tests.yaml @@ -0,0 +1,70 @@ +asyncio_replacements: # Replace entire modules + sleep: time.sleep + Queue: queue.Queue + Condition: threading.Condition + Future: concurrent.futures.Future + create_task: threading.Thread + +added_imports: + - "import google.api_core.exceptions as 
core_exceptions" + - "import threading" + - "import concurrent.futures" + - "from google.cloud.bigtable.data import Table" + +text_replacements: # Find and replace specific text patterns + __anext__: __next__ + __aiter__: __iter__ + __aenter__: __enter__ + __aexit__: __exit__ + aclose: close + AsyncIterable: Iterable + AsyncIterator: Iterator + StopAsyncIteration: StopIteration + Awaitable: None + BigtableDataClientAsync: BigtableDataClient + TableAsync: Table + AsyncMock: mock.Mock + retry_target_async: retry_target + TestBigtableDataClientAsync: TestBigtableDataClient + +classes: + #- path: tests.unit.data._async.test__mutate_rows.TestMutateRowsOperation + # autogen_sync_name: TestMutateRowsOperation + # replace_methods: + # _target_class: | + # from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation + # return _MutateRowsOperation + #- path: tests.unit.data._async.test__read_rows.TestReadRowsOperation + # autogen_sync_name: TestReadRowsOperation + # text_replacements: + # test_aclose: test_close + # replace_methods: + # _get_target_class: | + # from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + # return _ReadRowsOperation + #- path: tests.unit.data._async.test_mutations_batcher.Test_FlowControl + # autogen_sync_name: Test_FlowControl + # replace_methods: + # _target_class: | + # from google.cloud.bigtable.data._sync.mutations_batcher import _FlowControl + # return _FlowControl + #- path: tests.unit.data._async.test_mutations_batcher.TestMutationsBatcherAsync + # autogen_sync_name: TestMutationsBatcher + # replace_methods: + # _get_target_class: | + # from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher + # return MutationsBatcher + # is_async: "return False" + - path: tests.unit.data._async.test_client.TestBigtableDataClientAsync + autogen_sync_name: TestBigtableDataClient + replace_methods: + _get_target_class: | + from google.cloud.bigtable.data._sync.client import BigtableDataClient + return BigtableDataClient + drop_methods: ["test_client_ctor_sync", "test_channel_pool_creation", "test_channel_pool_rotation", "test_channel_pool_replace"] + - path: tests.unit.data._async.test_client.TestReadRowsShardedAsync + autogen_sync_name: TestReadRowsSharded + - path: tests.unit.data._async.test_client.TestReadRowsAsync + autogen_sync_name: TestReadRows + +save_path: "tests/unit/data/_sync/test_autogen.py" diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 7ca8fdc21..4a5bfca29 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1329,7 +1329,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ assert "app_profile_id=" not in goog_metadata -class TestReadRows: +class TestReadRowsAsync: """ Tests for table.read_rows and related methods. 
""" @@ -1835,7 +1835,7 @@ async def test_row_exists(self, return_value, expected_result): assert query.filter._to_dict() == expected_filter -class TestReadRowsSharded: +class TestReadRowsShardedAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) From 2cdf0ab6f3ae83b2034c152838b464e46ffc7e0f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Apr 2024 11:55:12 -0700 Subject: [PATCH 039/360] got client tests passing --- google/cloud/bigtable/data/_async/client.py | 2 +- google/cloud/bigtable/data/_sync/_autogen.py | 27 + google/cloud/bigtable/data/_sync/client.py | 167 +- .../cloud/bigtable/data/_sync/sync_gen.yaml | 3 +- .../cloud/bigtable/data/_sync/unit_tests.yaml | 30 +- .../bigtable_v2/services/bigtable/client.py | 2 + .../bigtable/transports/pooled_grpc.py | 438 ++++ .../transports/pooled_grpc_asyncio.py | 17 +- sync_surface_generator.py | 3 + tests/unit/data/_async/test_client.py | 220 +- tests/unit/data/_sync/test_autogen.py | 1864 ++++++++++------- 11 files changed, 1858 insertions(+), 915 deletions(-) create mode 100644 google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index b3ae852dc..b7090cc59 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -328,7 +328,7 @@ async def _manage_channel( # cycle channel out of use, with long grace window before closure start_timestamp = time.time() await self.transport.replace_channel( - channel_idx, grace=grace_period, swap_sleep=10, new_channel=new_channel + channel_idx, grace=grace_period, new_channel=new_channel ) # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index cc80ceaed..9de2ef21f 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -978,6 +978,33 @@ def _ping_and_warm_instances( - sequence of results or exceptions from the ping requests """ + def _manage_channel( + self, + channel_idx: int, + refresh_interval_min: float = 60 * 35, + refresh_interval_max: float = 60 * 45, + grace_period: float = 60 * 10, + ) -> None: + """ + Background coroutine that periodically refreshes and warms a grpc channel + + The backend will automatically close channels after 60 minutes, so + `refresh_interval` + `grace_period` should be < 60 minutes + + Runs continuously until the client is closed + + Args: + channel_idx: index of the channel in the transport's channel pool + refresh_interval_min: minimum interval before initiating refresh + process in seconds. Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + refresh_interval_max: maximum interval before initiating refresh + process in seconds. 
Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + grace_period: time to allow previous channel to serve existing + requests before closing, in seconds + """ + def _register_instance( self, instance_id: str, owner: google.cloud.bigtable.data._sync.client.Table ) -> None: diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index e4a8ab193..06a9dbb6d 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -16,7 +16,9 @@ from typing import Any, TYPE_CHECKING -import grpc +import time +import random +import threading import google.auth.credentials import concurrent.futures @@ -28,39 +30,166 @@ import google.cloud.bigtable.data._sync._read_rows # noqa: F401 import google.cloud.bigtable.data._sync._mutate_rows # noqa: F401 +from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledBigtableGrpcTransport, + PooledChannel +) +from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest + if TYPE_CHECKING: + import grpc from google.cloud.bigtable.data.row import Row + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey class BigtableDataClient(BigtableDataClient_SyncGen): - def __init__( - self, - *, - project: str | None = None, - credentials: google.auth.credentials.Credentials | None = None, - client_options: dict[str, Any] - | "google.api_core.client_options.ClientOptions" - | None = None, - **kwargs - ): - # remove pool size option in sync client - super().__init__( - project=project, credentials=credentials, client_options=client_options, pool_size=1 - ) + + @property + def _executor(self) -> concurrent.futures.ThreadPoolExecutor: + if not hasattr(self, "_executor_instance"): + self._executor_instance = concurrent.futures.ThreadPoolExecutor() + return self._executor_instance + + @property + def _is_closed(self) -> threading.Event: + if not hasattr(self, "_is_closed_instance"): + self._is_closed_instance = threading.Event() + return self._is_closed_instance def _transport_init(self, pool_size: int) -> str: - return "grpc" + transport_str = f"pooled_grpc_{pool_size}" + transport = PooledBigtableGrpcTransport.with_fixed_size(pool_size) + BigtableClientMeta._transport_registry[transport_str] = transport + return transport_str def _prep_emulator_channel(self, host:str, pool_size: int) -> str: - self.transport._grpc_channel = grpc.insecure_channel(target=host) + self.transport._grpc_channel = PooledChannel( + pool_size=pool_size, + host=host, + insecure=True, + ) @staticmethod def _client_version() -> str: return f"{google.cloud.bigtable.__version__}-data" def _start_background_channel_refresh(self) -> None: - # TODO: implement channel refresh - pass + if not self._channel_refresh_tasks and not self._emulator_host: + for channel_idx in range(self.transport.pool_size): + self._channel_refresh_tasks.append( + self._executor.submit(self._manage_channel, channel_idx) + ) + + def _manage_channel( + self, + channel_idx: int, + refresh_interval_min: float = 60 * 35, + refresh_interval_max: float = 60 * 45, + grace_period: float = 60 * 10, + ) -> None: + """ + Background routine that periodically refreshes and warms a grpc channel + + The backend will automatically close channels after 60 minutes, so + `refresh_interval` + `grace_period` should be < 60 minutes + + Runs continuously until the client is closed + + Args: + 
channel_idx: index of the channel in the transport's channel pool + refresh_interval_min: minimum interval before initiating refresh + process in seconds. Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + refresh_interval_max: maximum interval before initiating refresh + process in seconds. Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + grace_period: time to allow previous channel to serve existing + requests before closing, in seconds + """ + first_refresh = self._channel_init_time + random.uniform( + refresh_interval_min, refresh_interval_max + ) + next_sleep = max(first_refresh - time.monotonic(), 0) + if next_sleep > 0: + # warm the current channel immediately + channel = self.transport.channels[channel_idx] + self._ping_and_warm_instances(channel) + # continuously refresh the channel every `refresh_interval` seconds + while not self._is_closed.is_set(): + # sleep until next refresh, or until client is closed + self._is_closed.wait(next_sleep) + if self._is_closed.is_set(): + break + # prepare new channel for use + new_channel = self.transport.grpc_channel._create_channel() + self._ping_and_warm_instances(new_channel) + # cycle channel out of use, with long grace window before closure + start_timestamp = time.monotonic() + self.transport.replace_channel( + channel_idx, grace=grace_period, new_channel=new_channel, event=self._is_closed + ) + # subtract the time spent waiting for the channel to be replaced + next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) + next_sleep = next_refresh - (time.monotonic() - start_timestamp) + + def _ping_and_warm_instances( + self, channel: grpc.Channel, instance_key: _WarmedInstanceKey | None = None + ) -> list[BaseException | None]: + """ + Prepares the backend for requests on a channel + + Pings each Bigtable instance registered in `_active_instances` on the client + + Args: + - channel: grpc channel to warm + - instance_key: if provided, only warm the instance associated with the key + Returns: + - sequence of results or exceptions from the ping requests + """ + instance_list = ( + [instance_key] if instance_key is not None else self._active_instances + ) + ping_rpc = channel.unary_unary( + "/google.bigtable.v2.Bigtable/PingAndWarm", + request_serializer=PingAndWarmRequest.serialize, + ) + # execute pings in parallel + futures_list = [] + for (instance_name, table_name, app_profile_id) in instance_list: + future = self._executor.submit( + ping_rpc, + request={"name": instance_name, "app_profile_id": app_profile_id}, + metadata=[ + ( + "x-goog-request-params", + f"name={instance_name}&app_profile_id={app_profile_id}", + ) + ], + wait_for_ready=True, + ) + futures_list.append(future) + results_list = [] + for future in futures_list: + try: + future.result() + results_list.append(None) + except BaseException as e: + results_list.append(e) + return results_list + + def close(self) -> None: + """ + Close the client and all associated resources + + This method should be called when the client is no longer needed. 
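The refresh loop above sleeps on a `threading.Event` rather than calling `time.sleep`, so setting `_is_closed` from `close()` wakes any sleeping refresh thread immediately. A minimal standalone sketch of that pattern (the names here are stand-ins, not part of the patch):

```
import threading

is_closed = threading.Event()

def refresh_loop(next_sleep: float = 5.0):
    while not is_closed.is_set():
        # wait() returns True as soon as the event is set, or False after the timeout
        if is_closed.wait(next_sleep):
            break
        # ... create and warm a replacement channel here ...

worker = threading.Thread(target=refresh_loop)
worker.start()
is_closed.set()  # unblocks the wait immediately, like close() does
worker.join()
```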
+ """ + self._is_closed.set() + with self._executor: + self._executor.shutdown(wait=False) + self._channel_refresh_tasks = [] + self.transport.close() + super().close() class Table(Table_SyncGen): diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index 23568c0e4..8ef268068 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -48,8 +48,7 @@ classes: # Specify transformations for individual classes - path: google.cloud.bigtable.data._async.client.BigtableDataClientAsync autogen_sync_name: BigtableDataClient_SyncGen concrete_path: google.cloud.bigtable.data._sync.client.BigtableDataClient - pass_methods: ["close", "_ping_and_warm_instances"] - drop_methods: ["_manage_channel"] + pass_methods: ["close", "_ping_and_warm_instances", "_manage_channel"] error_methods: ["_start_background_channel_refresh", "_client_version", "_prep_emulator_channel", "_transport_init"] - path: google.cloud.bigtable.data._async.client.TableAsync autogen_sync_name: Table_SyncGen diff --git a/google/cloud/bigtable/data/_sync/unit_tests.yaml b/google/cloud/bigtable/data/_sync/unit_tests.yaml index ccb690a74..caad4ab15 100644 --- a/google/cloud/bigtable/data/_sync/unit_tests.yaml +++ b/google/cloud/bigtable/data/_sync/unit_tests.yaml @@ -22,10 +22,14 @@ text_replacements: # Find and replace specific text patterns StopAsyncIteration: StopIteration Awaitable: None BigtableDataClientAsync: BigtableDataClient + BigtableAsyncClient: BigtableClient TableAsync: Table AsyncMock: mock.Mock retry_target_async: retry_target TestBigtableDataClientAsync: TestBigtableDataClient + assert_awaited_once: assert_called_once + assert_awaited: assert_called_once + grpc_helpers_async: grpc_helpers classes: #- path: tests.unit.data._async.test__mutate_rows.TestMutateRowsOperation @@ -57,14 +61,30 @@ classes: # is_async: "return False" - path: tests.unit.data._async.test_client.TestBigtableDataClientAsync autogen_sync_name: TestBigtableDataClient + added_imports: + - "from google.api_core import grpc_helpers" + - "from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient" + - "from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport" + - "from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledChannel" + text_replacements: + PooledBigtableGrpcAsyncIOTransport: PooledBigtableGrpcTransport + PooledChannelAsync: PooledChannel + TestTableAsync: TestTable replace_methods: _get_target_class: | from google.cloud.bigtable.data._sync.client import BigtableDataClient return BigtableDataClient - drop_methods: ["test_client_ctor_sync", "test_channel_pool_creation", "test_channel_pool_rotation", "test_channel_pool_replace"] - - path: tests.unit.data._async.test_client.TestReadRowsShardedAsync - autogen_sync_name: TestReadRowsSharded - - path: tests.unit.data._async.test_client.TestReadRowsAsync - autogen_sync_name: TestReadRows + is_async: "return False" + drop_methods: ["test_client_ctor_sync", "test__start_background_channel_refresh_sync", "test__start_background_channel_refresh_tasks_names", "test_close_with_timeout"] + - path: tests.unit.data._async.test_client.TestTableAsync + autogen_sync_name: TestTable + replace_methods: + _get_target_class: | + from google.cloud.bigtable.data._sync.client import Table + return Table + #- path: tests.unit.data._async.test_client.TestReadRowsShardedAsync + # autogen_sync_name: TestReadRowsSharded + #- 
path: tests.unit.data._async.test_client.TestReadRowsAsync + # autogen_sync_name: TestReadRows save_path: "tests/unit/data/_sync/test_autogen.py" diff --git a/google/cloud/bigtable_v2/services/bigtable/client.py b/google/cloud/bigtable_v2/services/bigtable/client.py index f53f25e90..dc51fdfa1 100644 --- a/google/cloud/bigtable_v2/services/bigtable/client.py +++ b/google/cloud/bigtable_v2/services/bigtable/client.py @@ -55,6 +55,7 @@ from .transports.grpc import BigtableGrpcTransport from .transports.grpc_asyncio import BigtableGrpcAsyncIOTransport from .transports.pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport +from .transports.pooled_grpc import PooledBigtableGrpcTransport from .transports.rest import BigtableRestTransport @@ -70,6 +71,7 @@ class BigtableClientMeta(type): _transport_registry["grpc"] = BigtableGrpcTransport _transport_registry["grpc_asyncio"] = BigtableGrpcAsyncIOTransport _transport_registry["pooled_grpc_asyncio"] = PooledBigtableGrpcAsyncIOTransport + _transport_registry["pooled_grpc"] = PooledBigtableGrpcTransport _transport_registry["rest"] = BigtableRestTransport def get_transport_class( diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py new file mode 100644 index 000000000..f852017c0 --- /dev/null +++ b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py @@ -0,0 +1,438 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
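The `PooledChannel` defined in this new file hands each RPC to the next channel in a round-robin rotation. A minimal standalone sketch of that selection logic (a simplified stand-in class, not the actual transport):

```
class RoundRobinPool:
    def __init__(self, channels):
        self._pool = list(channels)
        self._next_idx = 0

    def next_channel(self):
        channel = self._pool[self._next_idx]
        # advance and wrap so traffic is spread evenly across the pool
        self._next_idx = (self._next_idx + 1) % len(self._pool)
        return channel

pool = RoundRobinPool(["channel_0", "channel_1", "channel_2"])
assert [pool.next_channel() for _ in range(4)] == [
    "channel_0", "channel_1", "channel_2", "channel_0"
]
```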
+# +import warnings +from functools import partialmethod +from functools import partial +import time +from typing import ( + Awaitable, + Callable, + Dict, + Optional, + Sequence, + Tuple, + Union, + List, + Type, +) + +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.bigtable_v2.types import bigtable +from .base import BigtableTransport, DEFAULT_CLIENT_INFO +from .grpc import BigtableGrpcTransport + + +class PooledMultiCallable: + def __init__(self, channel_pool: "PooledChannel", *args, **kwargs): + self._init_args = args + self._init_kwargs = kwargs + self.next_channel_fn = channel_pool.next_channel + + def with_call(self, *args, **kwargs): + raise NotImplementedError() + + def future(self, *args, **kwargs): + raise NotImplementedError() + +class PooledUnaryUnaryMultiCallable(PooledMultiCallable, grpc.UnaryUnaryMultiCallable): + def __call__(self, *args, **kwargs): + return self.next_channel_fn().unary_unary( + *self._init_args, **self._init_kwargs + )(*args, **kwargs) + + +class PooledUnaryStreamMultiCallable(PooledMultiCallable, grpc.UnaryStreamMultiCallable): + def __call__(self, *args, **kwargs): + return self.next_channel_fn().unary_stream( + *self._init_args, **self._init_kwargs + )(*args, **kwargs) + + +class PooledStreamUnaryMultiCallable(PooledMultiCallable, grpc.StreamUnaryMultiCallable): + def __call__(self, *args, **kwargs): + return self.next_channel_fn().stream_unary( + *self._init_args, **self._init_kwargs + )(*args, **kwargs) + + +class PooledStreamStreamMultiCallable( + PooledMultiCallable, grpc.StreamStreamMultiCallable +): + def __call__(self, *args, **kwargs): + return self.next_channel_fn().stream_stream( + *self._init_args, **self._init_kwargs + )(*args, **kwargs) + + +class PooledChannel(grpc.Channel): + def __init__( + self, + pool_size: int = 3, + host: str = "bigtable.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + quota_project_id: Optional[str] = None, + default_scopes: Optional[Sequence[str]] = None, + scopes: Optional[Sequence[str]] = None, + default_host: Optional[str] = None, + insecure: bool = False, + **kwargs, + ): + self._pool: List[grpc.Channel] = [] + self._next_idx = 0 + if insecure: + self._create_channel = partial(grpc.insecure_channel, host) + else: + self._create_channel = partial( + grpc_helpers.create_channel, + target=host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=default_scopes, + scopes=scopes, + default_host=default_host, + **kwargs, + ) + for i in range(pool_size): + self._pool.append(self._create_channel()) + + def next_channel(self) -> grpc.Channel: + channel = self._pool[self._next_idx] + self._next_idx = (self._next_idx + 1) % len(self._pool) + return channel + + def unary_unary(self, *args, **kwargs) -> grpc.UnaryUnaryMultiCallable: + return PooledUnaryUnaryMultiCallable(self, *args, **kwargs) + + def unary_stream(self, *args, **kwargs) -> grpc.UnaryStreamMultiCallable: + return PooledUnaryStreamMultiCallable(self, *args, **kwargs) + + def stream_unary(self, *args, **kwargs) -> grpc.StreamUnaryMultiCallable: + return PooledStreamUnaryMultiCallable(self, *args, **kwargs) + + def stream_stream(self, *args, **kwargs) -> grpc.StreamStreamMultiCallable: + return 
PooledStreamStreamMultiCallable(self, *args, **kwargs) + + def close(self): + for channel in self._pool: + channel.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: + raise NotImplementedError() + + def wait_for_state_change(self, last_observed_state): + raise NotImplementedError() + + def subscribe(self, callback, try_to_connect: bool = False) -> grpc.ChannelConnectivity: + raise NotImplementedError() + + def unsubscribe(self, callback): + raise NotImplementedError() + + def replace_channel( + self, channel_idx, grace=1, new_channel=None, event=None + ) -> grpc.Channel: + """ + Replaces a channel in the pool with a fresh one. + + The `new_channel` will start processing new requests immidiately, + but the old channel will continue serving existing clients for + `grace` seconds + + Args: + channel_idx(int): the channel index in the pool to replace + grace(Optional[float]): The time to wait for active RPCs to + finish. If a grace period is not specified (by passing None for + grace), all existing RPCs are cancelled immediately. + new_channel(grpc.Channel): a new channel to insert into the pool + at `channel_idx`. If `None`, a new channel will be created. + event(Optional[threading.Event]): an event to signal when the + replacement should be aborted. If set, will call `event.wait()` + instead of the `time.sleep` function. + """ + if channel_idx >= len(self._pool) or channel_idx < 0: + raise ValueError( + f"invalid channel_idx {channel_idx} for pool size {len(self._pool)}" + ) + if new_channel is None: + new_channel = self._create_channel() + old_channel = self._pool[channel_idx] + self._pool[channel_idx] = new_channel + if event: + event.wait(grace) + else: + time.sleep(grace) + old_channel.close() + return new_channel + + +class PooledBigtableGrpcTransport(BigtableGrpcTransport): + """Pooled gRPC backend transport for Bigtable. + + Service for reading from and writing to existing Bigtable + tables. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + + This class allows channel pooling, so multiple channels can be used concurrently + when making requests. Channels are rotated in a round-robin fashion. + """ + + @classmethod + def with_fixed_size(cls, pool_size) -> Type["PooledBigtableGrpcTransport"]: + """ + Creates a new class with a fixed channel pool size. + + A fixed channel pool makes compatibility with other transports easier, + as the initializer signature is the same. + """ + + class PooledTransportFixed(cls): + __init__ = partialmethod(cls.__init__, pool_size=pool_size) + + PooledTransportFixed.__name__ = f"{cls.__name__}_{pool_size}" + PooledTransportFixed.__qualname__ = PooledTransportFixed.__name__ + return PooledTransportFixed + + @classmethod + def create_channel( + cls, + pool_size: int = 3, + host: str = "bigtable.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a PooledChannel object, representing a pool of gRPC channels + Args: + pool_size (int): The number of channels in the pool. 
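`with_fixed_size` above pre-binds `pool_size` with `functools.partialmethod`, so the generated subclass can be constructed with the same signature as the single-channel transports and registered under a name such as `pooled_grpc_3`. A minimal standalone sketch of that technique (stand-in classes, not the real transport):

```
from functools import partialmethod

class FakeTransport:
    def __init__(self, host: str, pool_size: int = 3):
        self.host = host
        self.pool_size = pool_size

def with_fixed_size(cls, pool_size):
    class FixedPoolTransport(cls):
        # pool_size is already bound; callers only pass the usual arguments
        __init__ = partialmethod(cls.__init__, pool_size=pool_size)

    FixedPoolTransport.__name__ = f"{cls.__name__}_{pool_size}"
    return FixedPoolTransport

SevenChannelTransport = with_fixed_size(FakeTransport, 7)
transport = SevenChannelTransport("bigtable.googleapis.com")
assert transport.pool_size == 7
```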
+ host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + PooledChannel: a channel pool object + """ + + return PooledChannel( + pool_size, + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + pool_size: int = 3, + host: str = "bigtable.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + pool_size (int): the number of grpc channels to maintain in a pool + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. 
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + ValueError: if ``pool_size`` <= 0 + """ + if pool_size <= 0: + raise ValueError(f"invalid pool_size: {pool_size}") + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + BigtableTransport.__init__( + self, + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + self._quota_project_id = quota_project_id + self._grpc_channel = type(self).create_channel( + pool_size, + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=self._quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def pool_size(self) -> int: + """The number of grpc channels in the pool.""" + return len(self._grpc_channel._pool) + + @property + def channels(self) -> List[grpc.Channel]: + """Acccess the internal list of grpc channels.""" + return self._grpc_channel._pool + + def replace_channel( + self, channel_idx, grace=1, new_channel=None, event=None + ) -> grpc.Channel: + """ + Replaces a channel in the pool with a fresh one. 
+ + The `new_channel` will start processing new requests immidiately, + but the old channel will continue serving existing clients for `grace` seconds + + Args: + channel_idx(int): the channel index in the pool to replace + grace(Optional[float]): The time to wait for active RPCs to + finish. If a grace period is not specified (by passing None for + grace), all existing RPCs are cancelled immediately. + new_channel(grpc.Channel): a new channel to insert into the pool + at `channel_idx`. If `None`, a new channel will be created. + event(Optional[threading.Event]): an event to signal when the + replacement should be aborted. If set, will call `event.wait()` + instead of the `time.sleep` function. + """ + return self._grpc_channel.replace_channel( + channel_idx=channel_idx, grace=grace, new_channel=new_channel, event=event + ) + + +__all__ = ("PooledBigtableGrpcTransport",) diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py index 372e5796d..fa7ab4f59 100644 --- a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py +++ b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py @@ -150,7 +150,7 @@ async def wait_for_state_change(self, last_observed_state): raise NotImplementedError() async def replace_channel( - self, channel_idx, grace=None, swap_sleep=1, new_channel=None + self, channel_idx, grace=1, new_channel=None ) -> aio.Channel: """ Replaces a channel in the pool with a fresh one. @@ -160,11 +160,9 @@ async def replace_channel( Args: channel_idx(int): the channel index in the pool to replace - grace(Optional[float]): The time to wait until all active RPCs are - finished. If a grace period is not specified (by passing None for + grace(Optional[float]): The time to wait for active RPCs to + finish. If a grace period is not specified (by passing None for grace), all existing RPCs are cancelled immediately. - swap_sleep(Optional[float]): The number of seconds to sleep in between - replacing channels and closing the old one new_channel(grpc.aio.Channel): a new channel to insert into the pool at `channel_idx`. If `None`, a new channel will be created. """ @@ -176,7 +174,6 @@ async def replace_channel( new_channel = self._create_channel() old_channel = self._pool[channel_idx] self._pool[channel_idx] = new_channel - await asyncio.sleep(swap_sleep) await old_channel.close(grace=grace) return new_channel @@ -400,7 +397,7 @@ def channels(self) -> List[grpc.Channel]: return self._grpc_channel._pool async def replace_channel( - self, channel_idx, grace=None, swap_sleep=1, new_channel=None + self, channel_idx, grace=1, new_channel=None ) -> aio.Channel: """ Replaces a channel in the pool with a fresh one. @@ -410,16 +407,14 @@ async def replace_channel( Args: channel_idx(int): the channel index in the pool to replace - grace(Optional[float]): The time to wait until all active RPCs are + grace(Optional[float]): The time to wait for active RPCs to finished. If a grace period is not specified (by passing None for grace), all existing RPCs are cancelled immediately. - swap_sleep(Optional[float]): The number of seconds to sleep in between - replacing channels and closing the old one new_channel(grpc.aio.Channel): a new channel to insert into the pool at `channel_idx`. If `None`, a new channel will be created. 
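Both pool implementations follow the same grace-period swap: the replacement channel starts serving immediately, while the old channel is kept alive for `grace` seconds before being closed (the asyncio pool delegates the grace period to `close(grace=...)`, the sync pool waits explicitly before closing). A minimal standalone sketch of the sync variant (stand-in names only):

```
import time

def replace_channel(pool, channel_idx, new_channel, grace=1, event=None):
    old_channel = pool[channel_idx]
    pool[channel_idx] = new_channel  # new channel takes over immediately
    if event is not None:
        event.wait(grace)  # returns early if the client is shutting down
    else:
        time.sleep(grace)
    old_channel.close()
    return new_channel
```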
""" return await self._grpc_channel.replace_channel( - channel_idx, grace, swap_sleep, new_channel + channel_idx=channel_idx, grace=grace, new_channel=new_channel ) diff --git a/sync_surface_generator.py b/sync_surface_generator.py index d722de008..59856488a 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -334,6 +334,9 @@ def transform_from_config(config_dict: dict): # add globals to class_dict class_dict["asyncio_replacements"] = {**config_dict.get("asyncio_replacements", {}), **class_dict.get("asyncio_replacements", {})} class_dict["text_replacements"] = {**global_text_replacements, **class_dict.get("text_replacements", {})} + # add class-specific imports + for import_str in class_dict.pop("added_imports", []): + combined_imports.add(ast.parse(import_str).body[0]) # transform class tree_body, imports = transform_class(class_object, **class_dict) # update combined data diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 4a5bfca29..c761dd186 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -20,9 +20,17 @@ import pytest +from google.api_core import grpc_helpers_async from google.cloud.bigtable.data import mutations from google.auth.credentials import AnonymousCredentials from google.cloud.bigtable_v2.types import ReadRowsResponse +from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, +) +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, +) +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import PooledChannel as PooledChannelAsync from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.api_core import exceptions as core_exceptions from google.cloud.bigtable.data.exceptions import InvalidChunk @@ -40,10 +48,6 @@ import mock # type: ignore from mock import AsyncMock # type: ignore -VENEER_HEADER_REGEX = re.compile( - r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-data-async gl-python\/[0-9]+\.[\w.-]+ grpc\/[0-9]+\.[\w.-]+" -) - class TestBigtableDataClientAsync: @@ -71,6 +75,10 @@ def _make_client(cls, *args, use_emulator=True, **kwargs): with mock.patch.dict(os.environ, env_mask): return cls._get_target_class()(*args, **kwargs) + @property + def is_async(self): + return True + @pytest.mark.asyncio async def test_ctor(self): expected_project = "project-id" @@ -92,9 +100,6 @@ async def test_ctor(self): @pytest.mark.asyncio async def test_ctor_super_inits(self): - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) from google.cloud.client import ClientWithProject from google.api_core import client_options as client_options_lib @@ -103,7 +108,8 @@ async def test_ctor_super_inits(self): credentials = AnonymousCredentials() client_options = {"api_endpoint": "foo.bar:1234"} options_parsed = client_options_lib.from_dict(client_options) - transport_str = f"pooled_grpc_asyncio_{pool_size}" + asyncio_portion = "_asyncio" if self.is_async else "" + transport_str = f"pooled_grpc{asyncio_portion}_{pool_size}" with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: bigtable_client_init.return_value = None with mock.patch.object( @@ -135,9 +141,6 @@ async def test_ctor_super_inits(self): @pytest.mark.asyncio async def test_ctor_dict_options(self): - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - 
BigtableAsyncClient, - ) from google.api_core.client_options import ClientOptions client_options = {"api_endpoint": "foo.bar:1234"} @@ -160,9 +163,17 @@ async def test_ctor_dict_options(self): @pytest.mark.asyncio async def test_veneer_grpc_headers(self): + client_component = "data-async" if self.is_async else "data" + VENEER_HEADER_REGEX = re.compile( + r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-" + client_component + r" gl-python\/[0-9]+\.[\w.-]+ grpc\/[0-9]+\.[\w.-]+" + ) + # client_info should be populated with headers to # detect as a veneer client - patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") + if self.is_async: + patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") + else: + patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") with patch as gapic_mock: client = self._make_client(project="project-id") wrapped_call_list = gapic_mock.call_args_list @@ -182,10 +193,7 @@ async def test_veneer_grpc_headers(self): @pytest.mark.asyncio async def test_channel_pool_creation(self): pool_size = 14 - with mock.patch( - "google.api_core.grpc_helpers_async.create_channel" - ) as create_channel: - create_channel.return_value = AsyncMock() + with mock.patch.object(grpc_helpers_async, "create_channel", AsyncMock()) as create_channel: client = self._make_client(project="project-id", pool_size=pool_size) assert create_channel.call_count == pool_size await client.close() @@ -198,13 +206,9 @@ async def test_channel_pool_creation(self): @pytest.mark.asyncio async def test_channel_pool_rotation(self): - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel, - ) - pool_size = 7 - with mock.patch.object(PooledChannel, "next_channel") as next_channel: + with mock.patch.object(PooledChannelAsync, "next_channel") as next_channel: client = self._make_client(project="project-id", pool_size=pool_size) assert len(client.transport._grpc_channel._pool) == pool_size next_channel.reset_mock() @@ -226,7 +230,9 @@ async def test_channel_pool_rotation(self): @pytest.mark.asyncio async def test_channel_pool_replace(self): - with mock.patch.object(asyncio, "sleep"): + import time + sleep_module = asyncio if self.is_async else time + with mock.patch.object(sleep_module, "sleep"): pool_size = 7 client = self._make_client(project="project-id", pool_size=pool_size) for replace_idx in range(pool_size): @@ -235,14 +241,16 @@ async def test_channel_pool_replace(self): ] grace_period = 9 with mock.patch.object( - type(client.transport._grpc_channel._pool[0]), "close" + type(client.transport._grpc_channel._pool[-1]), "close" ) as close: - new_channel = grpc.aio.insecure_channel("localhost:8080") + new_channel = client.transport.create_channel() await client.transport.replace_channel( replace_idx, grace=grace_period, new_channel=new_channel ) - close.assert_called_once_with(grace=grace_period) - close.assert_awaited_once() + close.assert_called_once() + if self.is_async: + close.assert_called_once_with(grace=grace_period) + close.assert_awaited_once() assert client.transport._grpc_channel._pool[replace_idx] == new_channel for i in range(pool_size): if i != replace_idx: @@ -271,20 +279,23 @@ async def test__start_background_channel_refresh_tasks_exist(self): @pytest.mark.asyncio @pytest.mark.parametrize("pool_size", [1, 3, 7]) async def test__start_background_channel_refresh(self, pool_size): + import concurrent.futures # should create background tasks for each channel - client = self._make_client( - 
project="project-id", pool_size=pool_size, use_emulator=False - ) - ping_and_warm = AsyncMock() - client._ping_and_warm_instances = ping_and_warm - client._start_background_channel_refresh() - assert len(client._channel_refresh_tasks) == pool_size - for task in client._channel_refresh_tasks: - assert isinstance(task, asyncio.Task) - await asyncio.sleep(0.1) - assert ping_and_warm.call_count == pool_size - for channel in client.transport._grpc_channel._pool: - ping_and_warm.assert_any_call(channel) + with mock.patch.object(self._get_target_class(), "_ping_and_warm_instances", AsyncMock()) as ping_and_warm: + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + client._start_background_channel_refresh() + assert len(client._channel_refresh_tasks) == pool_size + for task in client._channel_refresh_tasks: + if self.is_async: + assert isinstance(task, asyncio.Task) + else: + assert isinstance(task, concurrent.futures.Future) + await asyncio.sleep(0.1) + assert ping_and_warm.call_count == pool_size + for channel in client.transport._grpc_channel._pool: + ping_and_warm.assert_any_call(channel) await client.close() @pytest.mark.asyncio @@ -309,9 +320,15 @@ async def test__ping_and_warm_instances(self): test ping and warm with mocked asyncio.gather """ client_mock = mock.Mock() - with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: - # simulate gather by returning the same number of items as passed in - gather.side_effect = lambda *args, **kwargs: [None for _ in args] + gather_tuple = (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") + with mock.patch.object(*gather_tuple, AsyncMock()) as gather: + if self.is_async: + # simulate gather by returning the same number of items as passed in + # gather is expected to return None for each coroutine passed + gather.side_effect = lambda *args, **kwargs: [None for _ in args] + else: + # submit is expected to call the function passed, and return the result + gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] channel = mock.Mock() # test with no instances client_mock._active_instances = [] @@ -319,10 +336,8 @@ async def test__ping_and_warm_instances(self): client_mock, channel ) assert len(result) == 0 - gather.assert_called_once() - gather.assert_awaited_once() - assert not gather.call_args.args - assert gather.call_args.kwargs == {"return_exceptions": True} + if self.is_async: + assert gather.call_args.kwargs == {"return_exceptions": True} # test with instances client_mock._active_instances = [ (mock.Mock(), mock.Mock(), mock.Mock()) @@ -333,9 +348,14 @@ async def test__ping_and_warm_instances(self): client_mock, channel ) assert len(result) == 4 - gather.assert_called_once() - gather.assert_awaited_once() - assert len(gather.call_args.args) == 4 + if self.is_async: + gather.assert_called_once() + gather.assert_awaited_once() + # expect one arg for each instance + assert len(gather.call_args.args) == 4 + else: + # expect one call for each instance + assert gather.call_count == 4 # check grpc call arguments grpc_call_args = channel.unary_unary().call_args_list for idx, (_, kwargs) in enumerate(grpc_call_args): @@ -361,9 +381,16 @@ async def test__ping_and_warm_single_instance(self): should be able to call ping and warm with single instance """ client_mock = mock.Mock() - with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: - # simulate gather by returning the same number of items as passed in - gather.side_effect = lambda *args, **kwargs: [None for _ in 
args] + gather_tuple = (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") + with mock.patch.object(*gather_tuple, AsyncMock()) as gather: + gather.side_effect = lambda *args, **kwargs: [mock.Mock() for _ in args] + if self.is_async: + # simulate gather by returning the same number of items as passed in + # gather is expected to return None for each coroutine passed + gather.side_effect = lambda *args, **kwargs: [None for _ in args] + else: + # submit is expected to call the function passed, and return the result + gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] channel = mock.Mock() # test with large set of instances client_mock._active_instances = [mock.Mock()] * 100 @@ -403,11 +430,13 @@ async def test__manage_channel_first_sleep( self, refresh_interval, wait_time, expected_sleep ): # first sleep time should be `refresh_interval` seconds after client init + import threading import time - with mock.patch.object(time, "monotonic") as time: - time.return_value = 0 - with mock.patch.object(asyncio, "sleep") as sleep: + with mock.patch.object(time, "monotonic") as monotonic: + monotonic.return_value = 0 + sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = asyncio.CancelledError try: client = self._make_client(project="project-id") @@ -428,15 +457,20 @@ async def test__manage_channel_ping_and_warm(self): _manage channel should call ping and warm internally """ import time + import threading client_mock = mock.Mock() + if not self.is_async: + # make sure loop is entered + client_mock._is_closed.is_set.return_value = False client_mock._channel_init_time = time.monotonic() channel_list = [mock.Mock(), mock.Mock()] client_mock.transport.channels = channel_list new_channel = mock.Mock() client_mock.transport.grpc_channel._create_channel.return_value = new_channel # should ping an warm all new channels, and old channels if sleeping - with mock.patch.object(asyncio, "sleep"): + sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + with mock.patch.object(*sleep_tuple): # stop process after replace_channel is called client_mock.transport.replace_channel.side_effect = asyncio.CancelledError ping_and_warm = client_mock._ping_and_warm_instances = AsyncMock() @@ -481,26 +515,29 @@ async def test__manage_channel_sleeps( # make sure that sleeps work as expected import time import random + import threading channel_idx = 1 with mock.patch.object(random, "uniform") as uniform: uniform.side_effect = lambda min_, max_: min_ - with mock.patch.object(time, "time") as time: - time.return_value = 0 - with mock.patch.object(asyncio, "sleep") as sleep: + with mock.patch.object(time, "time") as time_mock: + time_mock.return_value = 0 + sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = [None for i in range(num_cycles - 1)] + [ asyncio.CancelledError ] - try: - client = self._make_client(project="project-id") - if refresh_interval is not None: - await client._manage_channel( - channel_idx, refresh_interval, refresh_interval - ) - else: - await client._manage_channel(channel_idx) - except asyncio.CancelledError: - pass + client = self._make_client(project="project-id") + with mock.patch.object(client.transport, "replace_channel"): + try: + if refresh_interval is not None: + await client._manage_channel( + channel_idx, refresh_interval, refresh_interval + ) + else: + 
await client._manage_channel(channel_idx) + except asyncio.CancelledError: + pass assert sleep.call_count == num_cycles total_sleep = sum([call[0][0] for call in sleep.call_args_list]) assert ( @@ -511,8 +548,10 @@ async def test__manage_channel_sleeps( @pytest.mark.asyncio async def test__manage_channel_random(self): import random + import threading - with mock.patch.object(asyncio, "sleep") as sleep: + sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + with mock.patch.object(*sleep_tuple) as sleep: with mock.patch.object(random, "uniform") as uniform: uniform.return_value = 0 try: @@ -527,10 +566,11 @@ async def test__manage_channel_random(self): uniform.side_effect = lambda min_, max_: min_ sleep.side_effect = [None, None, asyncio.CancelledError] try: - await client._manage_channel(0, min_val, max_val) + with mock.patch.object(client.transport, "replace_channel"): + await client._manage_channel(0, min_val, max_val) except asyncio.CancelledError: pass - assert uniform.call_count == 2 + assert uniform.call_count == 3 uniform_args = [call[0] for call in uniform.call_args_list] for found_min, found_max in uniform_args: assert found_min == min_val @@ -540,20 +580,19 @@ async def test__manage_channel_random(self): @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) async def test__manage_channel_refresh(self, num_cycles): # make sure that channels are properly refreshed - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - ) - from google.api_core import grpc_helpers_async + import threading expected_grace = 9 expected_refresh = 0.5 channel_idx = 1 - new_channel = grpc.aio.insecure_channel("localhost:8080") + grpc_lib = grpc.aio if self.is_async else grpc + new_channel = grpc_lib.insecure_channel("localhost:8080") with mock.patch.object( PooledBigtableGrpcAsyncIOTransport, "replace_channel" ) as replace_channel: - with mock.patch.object(asyncio, "sleep") as sleep: + sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = [None for i in range(num_cycles)] + [ asyncio.CancelledError ] @@ -561,7 +600,8 @@ async def test__manage_channel_refresh(self, num_cycles): grpc_helpers_async, "create_channel" ) as create_channel: create_channel.return_value = new_channel - client = self._make_client(project="project-id", use_emulator=False) + with mock.patch.object(self._get_target_class(), "_start_background_channel_refresh"): + client = self._make_client(project="project-id", use_emulator=False) create_channel.reset_mock() try: await client._manage_channel( @@ -873,9 +913,7 @@ async def test_get_table_arg_passthrough(self): All arguments passed in get_table should be sent to constructor """ async with self._make_client(project="project-id") as client: - with mock.patch( - "google.cloud.bigtable.data._async.client.TableAsync.__init__", - ) as mock_constructor: + with mock.patch.object(TestTableAsync._get_target_class(), "__init__") as mock_constructor: mock_constructor.return_value = None assert not client._active_instances expected_table_id = "table-id" @@ -957,10 +995,6 @@ async def test_multiple_pool_sizes(self): @pytest.mark.asyncio async def test_close(self): - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - ) - pool_size = 7 client = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False @@ 
-977,7 +1011,6 @@ async def test_close(self): close_mock.assert_awaited() for task in tasks_list: assert task.done() - assert task.cancelled() assert client._channel_refresh_tasks == [] @pytest.mark.asyncio @@ -1032,9 +1065,14 @@ class TestTableAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.data._async.client import TableAsync + return TableAsync + + @pytest.mark.asyncio async def test_table_ctor(self): - from google.cloud.bigtable.data._async.client import TableAsync from google.cloud.bigtable.data._async.client import _WarmedInstanceKey expected_table_id = "table-id" @@ -1049,7 +1087,7 @@ async def test_table_ctor(self): client = self._make_client() assert not client._active_instances - table = TableAsync( + table = self._get_target_class()( client, expected_instance_id, expected_table_id, @@ -1302,9 +1340,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ from google.cloud.bigtable.data import TableAsync profile = "profile" if include_app_profile else None - with mock.patch( - f"google.cloud.bigtable_v2.BigtableAsyncClient.{gapic_fn}", mock.AsyncMock() - ) as gapic_mock: + with mock.patch.object(BigtableAsyncClient, gapic_fn, mock.AsyncMock()) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") async with self._make_client() as client: table = TableAsync(client, "instance-id", "table-id", profile) diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py index 14b4d0f5e..9c0791bc7 100644 --- a/tests/unit/data/_sync/test_autogen.py +++ b/tests/unit/data/_sync/test_autogen.py @@ -15,824 +15,1067 @@ # This file is automatically generated by sync_surface_generator.py. Do not edit. 
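A minimal sketch of the pattern the shared tests above and the generated tests below rely on (illustrative only, not part of the patch or of the autogenerated file): each test class reports an `is_async` flag, and the async/sync differences collapse into choosing a `(target, attribute)` tuple that is splatted into `mock.patch.object`, so a single test body can exercise both surfaces. The class and names below are hypothetical.

import asyncio
import threading
from unittest import mock


class _SleepPatchDemo:
    # the autogenerated sync test classes report False here; the async originals report True
    is_async = False

    def _patch_sleep(self):
        # mock.patch.object accepts (target, attribute), so switching surfaces is just
        # a choice of tuple, mirroring the sleep_tuple idiom used in these tests
        sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait")
        return mock.patch.object(*sleep_tuple)


# usage: the patched attribute records calls made through either surface
with _SleepPatchDemo()._patch_sleep() as fake_sleep:
    threading.Event().wait(0.01)
    assert fake_sleep.call_count == 1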
+from __future__ import annotations from abc import ABC from unittest import mock import asyncio -import concurrent.futures +import grpc import mock import pytest +import re import time +from google.api_core import exceptions as core_exceptions +from google.api_core import grpc_helpers +from google.auth.credentials import AnonymousCredentials from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data import Table +from google.cloud.bigtable.data import mutations from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete -import google.api_core.exceptions +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledBigtableGrpcTransport, +) +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledChannel, +) import google.api_core.exceptions as core_exceptions -import google.api_core.retry -class TestMutationsBatcher(ABC): - def _get_target_class(self): - from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher +class TestBigtableDataClient(ABC): + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.data._sync.client import BigtableDataClient - return MutationsBatcher + return BigtableDataClient - @staticmethod - def is_async(): - return False + @classmethod + def _make_client(cls, *args, use_emulator=True, **kwargs): + import os - def _make_one(self, table=None, **kwargs): - from google.api_core.exceptions import DeadlineExceeded - from google.api_core.exceptions import ServiceUnavailable - - if table is None: - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 10 - table.default_mutate_rows_retryable_errors = ( - DeadlineExceeded, - ServiceUnavailable, - ) - return self._get_target_class()(table, **kwargs) + env_mask = {} + if use_emulator: + env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" + import warnings - @staticmethod - def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation + warnings.filterwarnings("ignore", category=RuntimeWarning) + else: + kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) + kwargs["project"] = kwargs.get("project", "project-id") + with mock.patch.dict(os.environ, env_mask): + return cls._get_target_class()(*args, **kwargs) - def test_ctor_defaults(self): + @property + def is_async(self): + return False + + def test_ctor(self): + expected_project = "project-id" + expected_pool_size = 11 + expected_credentials = AnonymousCredentials() + client = self._make_client( + project="project-id", + pool_size=expected_pool_size, + credentials=expected_credentials, + use_emulator=False, + ) + time.sleep(0) + assert client.project == expected_project + assert len(client.transport._grpc_channel._pool) == expected_pool_size + assert not client._active_instances + assert len(client._channel_refresh_tasks) == expected_pool_size + assert client.transport._credentials == expected_credentials + client.close() + + def test_ctor_super_inits(self): + from google.cloud.client import ClientWithProject + from google.api_core import client_options as client_options_lib + + project = "project-id" + pool_size = 11 + credentials = AnonymousCredentials() + client_options = {"api_endpoint": "foo.bar:1234"} + 
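# --- Editor's illustrative sketch; not part of this patch. ---
# The veneer-header tests above (async) and below (sync) build their regex from a
# per-surface client_component and sort the user-agent tokens before matching,
# since to_user_agent() may emit the metric tokens in any order.  The version
# numbers below are made up for illustration.
import re

client_component = "data"  # the async surface uses "data-async"
VENEER_HEADER_REGEX = re.compile(
    r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-"
    + client_component
    + r" gl-python\/[0-9]+\.[\w.-]+ grpc\/[0-9]+\.[\w.-]+"
)

raw_user_agent = "grpc/1.60.0 gl-python/3.11.4 gapic/2.19.1 gax/2.17.0 gccl/2.23.0-data"
normalized = " ".join(sorted(raw_user_agent.split(" ")))
assert VENEER_HEADER_REGEX.match(normalized)
# --- end sketch ---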
options_parsed = client_options_lib.from_dict(client_options) + asyncio_portion = "_asyncio" if self.is_async else "" + transport_str = f"pooled_grpc{asyncio_portion}_{pool_size}" + with mock.patch.object(BigtableClient, "__init__") as bigtable_client_init: + bigtable_client_init.return_value = None + with mock.patch.object( + ClientWithProject, "__init__" + ) as client_project_init: + client_project_init.return_value = None + try: + self._make_client( + project=project, + pool_size=pool_size, + credentials=credentials, + client_options=options_parsed, + use_emulator=False, + ) + except AttributeError: + pass + assert bigtable_client_init.call_count == 1 + kwargs = bigtable_client_init.call_args[1] + assert kwargs["transport"] == transport_str + assert kwargs["credentials"] == credentials + assert kwargs["client_options"] == options_parsed + assert client_project_init.call_count == 1 + kwargs = client_project_init.call_args[1] + assert kwargs["project"] == project + assert kwargs["credentials"] == credentials + assert kwargs["client_options"] == options_parsed + + def test_ctor_dict_options(self): + from google.api_core.client_options import ClientOptions + + client_options = {"api_endpoint": "foo.bar:1234"} + with mock.patch.object(BigtableClient, "__init__") as bigtable_client_init: + try: + self._make_client(client_options=client_options) + except TypeError: + pass + bigtable_client_init.assert_called_once() + kwargs = bigtable_client_init.call_args[1] + called_options = kwargs["client_options"] + assert called_options.api_endpoint == "foo.bar:1234" + assert isinstance(called_options, ClientOptions) with mock.patch.object( - self._get_target_class(), - "_timer_routine", - return_value=concurrent.futures.Future(), - ) as flush_timer_mock: - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = [Exception] - with self._make_one(table) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._max_mutation_count == 100000 - assert instance._flow_control._max_mutation_bytes == 104857600 - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - assert ( - instance._operation_timeout - == table.default_mutate_rows_operation_timeout - ) - assert ( - instance._attempt_timeout - == table.default_mutate_rows_attempt_timeout - ) - assert ( - instance._retryable_errors - == table.default_mutate_rows_retryable_errors + self._get_target_class(), "_start_background_channel_refresh" + ) as start_background_refresh: + client = self._make_client( + client_options=client_options, use_emulator=False + ) + start_background_refresh.assert_called_once() + client.close() + + def test_veneer_grpc_headers(self): + client_component = "data-async" if self.is_async else "data" + VENEER_HEADER_REGEX = re.compile( + "gapic\\/[0-9]+\\.[\\w.-]+ gax\\/[0-9]+\\.[\\w.-]+ gccl\\/[0-9]+\\.[\\w.-]+-" + + client_component + + " gl-python\\/[0-9]+\\.[\\w.-]+ grpc\\/[0-9]+\\.[\\w.-]+" + ) + if self.is_async: + patch = 
mock.patch("google.api_core.gapic_v1.method_async.wrap_method") + else: + patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") + with patch as gapic_mock: + client = self._make_client(project="project-id") + wrapped_call_list = gapic_mock.call_args_list + assert len(wrapped_call_list) > 0 + for call in wrapped_call_list: + client_info = call.kwargs["client_info"] + assert client_info is not None, f"{call} has no client_info" + wrapped_user_agent_sorted = " ".join( + sorted(client_info.to_user_agent().split(" ")) ) - time.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] == 5 - assert isinstance(instance._flush_timer, concurrent.futures.Future) + assert VENEER_HEADER_REGEX.match( + wrapped_user_agent_sorted + ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" + client.close() - def test_ctor_explicit(self): - """Test with explicit parameters""" + def test_channel_pool_creation(self): + pool_size = 14 with mock.patch.object( - self._get_target_class(), - "_timer_routine", - return_value=concurrent.futures.Future(), - ) as flush_timer_mock: - table = mock.Mock() - flush_interval = 20 - flush_limit_count = 17 - flush_limit_bytes = 19 - flow_control_max_mutation_count = 1001 - flow_control_max_bytes = 12 - operation_timeout = 11 - attempt_timeout = 2 - retryable_errors = [Exception] - with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - flow_control_max_mutation_count=flow_control_max_mutation_count, - flow_control_max_bytes=flow_control_max_bytes, - batch_operation_timeout=operation_timeout, - batch_attempt_timeout=attempt_timeout, - batch_retryable_errors=retryable_errors, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert ( - instance._flow_control._max_mutation_count - == flow_control_max_mutation_count - ) + grpc_helpers, "create_channel", mock.Mock() + ) as create_channel: + client = self._make_client(project="project-id", pool_size=pool_size) + assert create_channel.call_count == pool_size + client.close() + client = self._make_client(project="project-id", pool_size=pool_size) + pool_list = list(client.transport._grpc_channel._pool) + pool_set = set(client.transport._grpc_channel._pool) + assert len(pool_list) == len(pool_set) + client.close() + + def test_channel_pool_rotation(self): + pool_size = 7 + with mock.patch.object(PooledChannel, "next_channel") as next_channel: + client = self._make_client(project="project-id", pool_size=pool_size) + assert len(client.transport._grpc_channel._pool) == pool_size + next_channel.reset_mock() + with mock.patch.object( + type(client.transport._grpc_channel._pool[0]), "unary_unary" + ) as unary_unary: + channel_next = None + for i in range(pool_size): + channel_last = channel_next + channel_next = client.transport.grpc_channel._pool[i] + assert channel_last != channel_next + next_channel.return_value = channel_next + client.transport.ping_and_warm() + assert next_channel.call_count == i + 1 + unary_unary.assert_called_once() + unary_unary.reset_mock() + client.close() + + def test_channel_pool_replace(self): + import time + + sleep_module = asyncio if 
self.is_async else time + with mock.patch.object(sleep_module, "sleep"): + pool_size = 7 + client = self._make_client(project="project-id", pool_size=pool_size) + for replace_idx in range(pool_size): + start_pool = [ + channel for channel in client.transport._grpc_channel._pool + ] + grace_period = 9 + with mock.patch.object( + type(client.transport._grpc_channel._pool[-1]), "close" + ) as close: + new_channel = client.transport.create_channel() + client.transport.replace_channel( + replace_idx, grace=grace_period, new_channel=new_channel + ) + close.assert_called_once() + if self.is_async: + close.assert_called_once_with(grace=grace_period) + close.assert_called_once() + assert client.transport._grpc_channel._pool[replace_idx] == new_channel + for i in range(pool_size): + if i != replace_idx: + assert client.transport._grpc_channel._pool[i] == start_pool[i] + else: + assert client.transport._grpc_channel._pool[i] != start_pool[i] + client.close() + + def test__start_background_channel_refresh_tasks_exist(self): + client = self._make_client(project="project-id", use_emulator=False) + assert len(client._channel_refresh_tasks) > 0 + with mock.patch.object(asyncio, "create_task") as create_task: + client._start_background_channel_refresh() + create_task.assert_not_called() + client.close() + + @pytest.mark.parametrize("pool_size", [1, 3, 7]) + def test__start_background_channel_refresh(self, pool_size): + import concurrent.futures + + with mock.patch.object( + self._get_target_class(), "_ping_and_warm_instances", mock.Mock() + ) as ping_and_warm: + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + client._start_background_channel_refresh() + assert len(client._channel_refresh_tasks) == pool_size + for task in client._channel_refresh_tasks: + if self.is_async: + assert isinstance(task, asyncio.Task) + else: + assert isinstance(task, concurrent.futures.Future) + time.sleep(0.1) + assert ping_and_warm.call_count == pool_size + for channel in client.transport._grpc_channel._pool: + ping_and_warm.assert_any_call(channel) + client.close() + + def test__ping_and_warm_instances(self): + """test ping and warm with mocked asyncio.gather""" + client_mock = mock.Mock() + gather_tuple = ( + (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") + ) + with mock.patch.object(*gather_tuple, mock.Mock()) as gather: + if self.is_async: + gather.side_effect = lambda *args, **kwargs: [None for _ in args] + else: + gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] + channel = mock.Mock() + client_mock._active_instances = [] + result = self._get_target_class()._ping_and_warm_instances( + client_mock, channel + ) + assert len(result) == 0 + if self.is_async: + assert gather.call_args.kwargs == {"return_exceptions": True} + client_mock._active_instances = [ + (mock.Mock(), mock.Mock(), mock.Mock()) + ] * 4 + gather.reset_mock() + channel.reset_mock() + result = self._get_target_class()._ping_and_warm_instances( + client_mock, channel + ) + assert len(result) == 4 + if self.is_async: + gather.assert_called_once() + gather.assert_called_once() + assert len(gather.call_args.args) == 4 + else: + assert gather.call_count == 4 + grpc_call_args = channel.unary_unary().call_args_list + for idx, (_, kwargs) in enumerate(grpc_call_args): + ( + expected_instance, + expected_table, + expected_app_profile, + ) = client_mock._active_instances[idx] + request = kwargs["request"] + assert request["name"] == expected_instance + assert 
request["app_profile_id"] == expected_app_profile + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" assert ( - instance._flow_control._max_mutation_bytes == flow_control_max_bytes + metadata[0][1] + == f"name={expected_instance}&app_profile_id={expected_app_profile}" ) - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - assert instance._operation_timeout == operation_timeout - assert instance._attempt_timeout == attempt_timeout - assert instance._retryable_errors == retryable_errors - time.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] == flush_interval - assert isinstance(instance._flush_timer, concurrent.futures.Future) - - def test_ctor_no_flush_limits(self): - """Test with None for flush limits""" - with mock.patch.object( - self._get_target_class(), - "_timer_routine", - return_value=concurrent.futures.Future(), - ) as flush_timer_mock: - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = () - flush_interval = None - flush_limit_count = None - flush_limit_bytes = None - with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._staged_entries == [] - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - time.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] is None - assert isinstance(instance._flush_timer, concurrent.futures.Future) - - def test_ctor_invalid_values(self): - """Test that timeout values are positive, and fit within expected limits""" - with pytest.raises(ValueError) as e: - self._make_one(batch_operation_timeout=-1) - assert "operation_timeout must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - self._make_one(batch_attempt_timeout=-1) - assert "attempt_timeout must be greater than 0" in str(e.value) - - def test_default_argument_consistency(self): - """ - We supply default arguments in MutationsBatcherAsync.__init__, and in - table.mutations_batcher. 
Make sure any changes to defaults are applied to - both places - """ - import inspect - get_batcher_signature = dict( - inspect.signature(Table.mutations_batcher).parameters - ) - get_batcher_signature.pop("self") - batcher_init_signature = dict( - inspect.signature(self._get_target_class()).parameters + def test__ping_and_warm_single_instance(self): + """should be able to call ping and warm with single instance""" + client_mock = mock.Mock() + gather_tuple = ( + (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") ) - batcher_init_signature.pop("table") - assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) - assert len(get_batcher_signature) == 8 - assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) - for arg_name in get_batcher_signature.keys(): + with mock.patch.object(*gather_tuple, mock.Mock()) as gather: + gather.side_effect = lambda *args, **kwargs: [mock.Mock() for _ in args] + if self.is_async: + gather.side_effect = lambda *args, **kwargs: [None for _ in args] + else: + gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] + channel = mock.Mock() + client_mock._active_instances = [mock.Mock()] * 100 + test_key = ("test-instance", "test-table", "test-app-profile") + result = self._get_target_class()._ping_and_warm_instances( + client_mock, channel, test_key + ) + assert len(result) == 1 + grpc_call_args = channel.unary_unary().call_args_list + assert len(grpc_call_args) == 1 + kwargs = grpc_call_args[0][1] + request = kwargs["request"] + assert request["name"] == "test-instance" + assert request["app_profile_id"] == "test-app-profile" + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" assert ( - get_batcher_signature[arg_name].default - == batcher_init_signature[arg_name].default + metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" ) - @pytest.mark.parametrize("input_val", [None, 0, -1]) - def test__start_flush_timer_w_empty_input(self, input_val): - """Empty/invalid timer should return immediately""" - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: - with self._make_one() as instance: - if self.is_async(): - (sleep_obj, sleep_method) = (asyncio, "wait_for") - else: - (sleep_obj, sleep_method) = (instance._closed, "wait") - with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: - result = instance._timer_routine(input_val) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 - assert result is None - - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test__start_flush_timer_call_when_closed(self): - """closed batcher's timer should return immediately""" - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: - with self._make_one() as instance: - instance.close() - flush_mock.reset_mock() - if self.is_async(): - (sleep_obj, sleep_method) = (asyncio, "wait_for") - else: - (sleep_obj, sleep_method) = (instance._closed, "wait") - with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: - instance._timer_routine(10) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 - - @pytest.mark.parametrize("num_staged", [0, 1, 10]) - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test__flush_timer(self, num_staged): - """Timer should continue to call _schedule_flush in a loop""" - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: - expected_sleep = 
12 - with self._make_one(flush_interval=expected_sleep) as instance: - loop_num = 3 - instance._staged_entries = [mock.Mock()] * num_staged - if self.is_async(): - (sleep_obj, sleep_method) = (asyncio, "wait_for") - else: - (sleep_obj, sleep_method) = (instance._closed, "wait") - with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: - sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] - with pytest.raises(TabError): - self._get_target_class()._timer_routine( - instance, expected_sleep - ) - instance._flush_timer = concurrent.futures.Future() - assert sleep_mock.call_count == loop_num + 1 - sleep_kwargs = sleep_mock.call_args[1] - assert sleep_kwargs["timeout"] == expected_sleep - assert flush_mock.call_count == (0 if num_staged == 0 else loop_num) - - def test__flush_timer_close(self): - """Timer should continue terminate after close""" - with mock.patch.object(self._get_target_class(), "_schedule_flush"): - with self._make_one() as instance: - with mock.patch("asyncio.sleep"): - time.sleep(0.5) - assert instance._flush_timer.done() is False - instance.close() - time.sleep(0.1) - assert instance._flush_timer.done() is True - - def test_append_closed(self): - """Should raise exception""" - instance = self._make_one() - instance.close() - with pytest.raises(RuntimeError): - instance.append(mock.Mock()) - - def test_append_wrong_mutation(self): - """ - Mutation objects should raise an exception. - Only support RowMutationEntry - """ - from google.cloud.bigtable.data.mutations import DeleteAllFromRow - - with self._make_one() as instance: - expected_error = "invalid mutation type: DeleteAllFromRow. Only RowMutationEntry objects are supported by batcher" - with pytest.raises(ValueError) as e: - instance.append(DeleteAllFromRow()) - assert str(e.value) == expected_error - - def test_append_outside_flow_limits(self): - """entries larger than mutation limits are still processed""" - with self._make_one( - flow_control_max_mutation_count=1, flow_control_max_bytes=1 - ) as instance: - oversized_entry = self._make_mutation(count=0, size=2) - instance.append(oversized_entry) - assert instance._staged_entries == [oversized_entry] - assert instance._staged_count == 0 - assert instance._staged_bytes == 2 - instance._staged_entries = [] - with self._make_one( - flow_control_max_mutation_count=1, flow_control_max_bytes=1 - ) as instance: - overcount_entry = self._make_mutation(count=2, size=0) - instance.append(overcount_entry) - assert instance._staged_entries == [overcount_entry] - assert instance._staged_count == 2 - assert instance._staged_bytes == 0 - instance._staged_entries = [] - - def test_append_flush_runs_after_limit_hit(self): - """ - If the user appends a bunch of entries above the flush limits back-to-back, - it should still flush in a single task - """ - with mock.patch.object( - self._get_target_class(), "_execute_mutate_rows" - ) as op_mock: - with self._make_one(flush_limit_bytes=100) as instance: - - def mock_call(*args, **kwargs): - return [] - - op_mock.side_effect = mock_call - instance.append(self._make_mutation(size=99)) - num_entries = 10 - for _ in range(num_entries): - instance.append(self._make_mutation(size=1)) - instance._wait_for_batch_results(*instance._flush_jobs) - assert op_mock.call_count == 1 - sent_batch = op_mock.call_args[0][0] - assert len(sent_batch) == 2 - assert len(instance._staged_entries) == num_entries - 1 - @pytest.mark.parametrize( - "flush_count,flush_bytes,mutation_count,mutation_bytes,expect_flush", - [ - (10, 10, 1, 1, False), 
- (10, 10, 9, 9, False), - (10, 10, 10, 1, True), - (10, 10, 1, 10, True), - (10, 10, 10, 10, True), - (1, 1, 10, 10, True), - (1, 1, 0, 0, False), - ], + "refresh_interval, wait_time, expected_sleep", + [(0, 0, 0), (0, 1, 0), (10, 0, 10), (10, 5, 5), (10, 10, 0), (10, 15, 0)], ) - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test_append( - self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush + def test__manage_channel_first_sleep( + self, refresh_interval, wait_time, expected_sleep ): - """test appending different mutations, and checking if it causes a flush""" - with self._make_one( - flush_limit_mutation_count=flush_count, flush_limit_bytes=flush_bytes - ) as instance: - assert instance._staged_count == 0 - assert instance._staged_bytes == 0 - assert instance._staged_entries == [] - mutation = self._make_mutation(count=mutation_count, size=mutation_bytes) - with mock.patch.object(instance, "_schedule_flush") as flush_mock: - instance.append(mutation) - assert flush_mock.call_count == bool(expect_flush) - assert instance._staged_count == mutation_count - assert instance._staged_bytes == mutation_bytes - assert instance._staged_entries == [mutation] - instance._staged_entries = [] - - def test_append_multiple_sequentially(self): - """Append multiple mutations""" - with self._make_one( - flush_limit_mutation_count=8, flush_limit_bytes=8 - ) as instance: - assert instance._staged_count == 0 - assert instance._staged_bytes == 0 - assert instance._staged_entries == [] - mutation = self._make_mutation(count=2, size=3) - with mock.patch.object(instance, "_schedule_flush") as flush_mock: - instance.append(mutation) - assert flush_mock.call_count == 0 - assert instance._staged_count == 2 - assert instance._staged_bytes == 3 - assert len(instance._staged_entries) == 1 - instance.append(mutation) - assert flush_mock.call_count == 0 - assert instance._staged_count == 4 - assert instance._staged_bytes == 6 - assert len(instance._staged_entries) == 2 - instance.append(mutation) - assert flush_mock.call_count == 1 - assert instance._staged_count == 6 - assert instance._staged_bytes == 9 - assert len(instance._staged_entries) == 3 - instance._staged_entries = [] - - def test_flush_flow_control_concurrent_requests(self): - """requests should happen in parallel if flow control breaks up single flush into batches""" + import threading import time - num_calls = 10 - fake_mutations = [self._make_mutation(count=1) for _ in range(num_calls)] - with self._make_one(flow_control_max_mutation_count=1) as instance: - with mock.patch.object( - instance, "_execute_mutate_rows", mock.Mock() - ) as op_mock: - - def mock_call(*args, **kwargs): - time.sleep(0.1) - return [] - - op_mock.side_effect = mock_call - start_time = time.monotonic() - instance._staged_entries = fake_mutations - instance._schedule_flush() - time.sleep(0.01) - for i in range(num_calls): - instance._flow_control.remove_from_flow( - [self._make_mutation(count=1)] - ) - time.sleep(0.01) - instance._wait_for_batch_results(*instance._flush_jobs) - duration = time.monotonic() - start_time - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert duration < 0.5 - assert op_mock.call_count == num_calls - - def test_schedule_flush_no_mutations(self): - """schedule flush should return None if no staged mutations""" - with self._make_one() as instance: - with mock.patch.object(instance, "_flush_internal") as flush_mock: - for i in range(3): - assert instance._schedule_flush() 
is None - assert flush_mock.call_count == 0 - - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test_schedule_flush_with_mutations(self): - """if new mutations exist, should add a new flush task to _flush_jobs""" - with self._make_one() as instance: - with mock.patch.object(instance, "_flush_internal") as flush_mock: - if not self.is_async(): - flush_mock.side_effect = lambda x: time.sleep(0.1) - for i in range(1, 4): - mutation = mock.Mock() - instance._staged_entries = [mutation] - instance._schedule_flush() - assert instance._staged_entries == [] - time.sleep(0) - assert instance._staged_entries == [] - assert instance._staged_count == 0 - assert instance._staged_bytes == 0 - assert flush_mock.call_count == 1 - flush_mock.reset_mock() + with mock.patch.object(time, "monotonic") as monotonic: + monotonic.return_value = 0 + sleep_tuple = ( + (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: + sleep.side_effect = asyncio.CancelledError + try: + client = self._make_client(project="project-id") + client._channel_init_time = -wait_time + client._manage_channel(0, refresh_interval, refresh_interval) + except asyncio.CancelledError: + pass + sleep.assert_called_once() + call_time = sleep.call_args[0][0] + assert ( + abs(call_time - expected_sleep) < 0.1 + ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" + client.close() - def test__flush_internal(self): - """ - _flush_internal should: - - await previous flush call - - delegate batching to _flow_control - - call _execute_mutate_rows on each batch - - update self.exceptions and self._entries_processed_since_last_raise - """ - num_entries = 10 - with self._make_one() as instance: - with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: + def test__manage_channel_ping_and_warm(self): + """_manage channel should call ping and warm internally""" + import time + import threading + + client_mock = mock.Mock() + if not self.is_async: + client_mock._is_closed.is_set.return_value = False + client_mock._channel_init_time = time.monotonic() + channel_list = [mock.Mock(), mock.Mock()] + client_mock.transport.channels = channel_list + new_channel = mock.Mock() + client_mock.transport.grpc_channel._create_channel.return_value = new_channel + sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + with mock.patch.object(*sleep_tuple): + client_mock.transport.replace_channel.side_effect = asyncio.CancelledError + ping_and_warm = client_mock._ping_and_warm_instances = mock.Mock() + try: + channel_idx = 1 + self._get_target_class()._manage_channel(client_mock, channel_idx, 10) + except asyncio.CancelledError: + pass + assert ping_and_warm.call_count == 2 + assert client_mock.transport.replace_channel.call_count == 1 + old_channel = channel_list[channel_idx] + assert old_channel != new_channel + called_with = [call[0][0] for call in ping_and_warm.call_args_list] + assert old_channel in called_with + assert new_channel in called_with + ping_and_warm.reset_mock() + try: + self._get_target_class()._manage_channel(client_mock, 0, 0, 0) + except asyncio.CancelledError: + pass + ping_and_warm.assert_called_once_with(new_channel) + + @pytest.mark.parametrize( + "refresh_interval, num_cycles, expected_sleep", + [(None, 1, 60 * 35), (10, 10, 100), (10, 1, 10)], + ) + def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sleep): + import time + import random + import 
threading + + channel_idx = 1 + with mock.patch.object(random, "uniform") as uniform: + uniform.side_effect = lambda min_, max_: min_ + with mock.patch.object(time, "time") as time_mock: + time_mock.return_value = 0 + sleep_tuple = ( + (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: + sleep.side_effect = [None for i in range(num_cycles - 1)] + [ + asyncio.CancelledError + ] + client = self._make_client(project="project-id") + with mock.patch.object(client.transport, "replace_channel"): + try: + if refresh_interval is not None: + client._manage_channel( + channel_idx, refresh_interval, refresh_interval + ) + else: + client._manage_channel(channel_idx) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + total_sleep = sum([call[0][0] for call in sleep.call_args_list]) + assert ( + abs(total_sleep - expected_sleep) < 0.1 + ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" + client.close() + + def test__manage_channel_random(self): + import random + import threading + + sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + with mock.patch.object(*sleep_tuple) as sleep: + with mock.patch.object(random, "uniform") as uniform: + uniform.return_value = 0 + try: + uniform.side_effect = asyncio.CancelledError + client = self._make_client(project="project-id", pool_size=1) + except asyncio.CancelledError: + uniform.side_effect = None + uniform.reset_mock() + sleep.reset_mock() + min_val = 200 + max_val = 205 + uniform.side_effect = lambda min_, max_: min_ + sleep.side_effect = [None, None, asyncio.CancelledError] + try: + with mock.patch.object(client.transport, "replace_channel"): + client._manage_channel(0, min_val, max_val) + except asyncio.CancelledError: + pass + assert uniform.call_count == 3 + uniform_args = [call[0] for call in uniform.call_args_list] + for found_min, found_max in uniform_args: + assert found_min == min_val + assert found_max == max_val + + @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) + def test__manage_channel_refresh(self, num_cycles): + import threading + + expected_grace = 9 + expected_refresh = 0.5 + channel_idx = 1 + grpc_lib = grpc.aio if self.is_async else grpc + new_channel = grpc_lib.insecure_channel("localhost:8080") + with mock.patch.object( + PooledBigtableGrpcTransport, "replace_channel" + ) as replace_channel: + sleep_tuple = ( + (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: + sleep.side_effect = [None for i in range(num_cycles)] + [ + asyncio.CancelledError + ] with mock.patch.object( - instance._flow_control, "add_to_flow" - ) as flow_mock: - - def gen(x): - yield x - - flow_mock.side_effect = lambda x: gen(x) - mutations = [self._make_mutation(count=1, size=1)] * num_entries - instance._flush_internal(mutations) - assert instance._entries_processed_since_last_raise == num_entries - assert execute_mock.call_count == 1 - assert flow_mock.call_count == 1 - instance._oldest_exceptions.clear() - instance._newest_exceptions.clear() - - def test_flush_clears_job_list(self): - """ - a job should be added to _flush_jobs when _schedule_flush is called, - and removed when it completes - """ - with self._make_one() as instance: - with mock.patch.object( - instance, "_flush_internal", mock.Mock() - ) as flush_mock: - if not self.is_async(): - flush_mock.side_effect = lambda x: time.sleep(0.1) - mutations = 
[self._make_mutation(count=1, size=1)] - instance._staged_entries = mutations - assert instance._flush_jobs == set() - new_job = instance._schedule_flush() - assert instance._flush_jobs == {new_job} - if self.is_async(): - new_job - else: - new_job.result() - assert instance._flush_jobs == set() + grpc_helpers, "create_channel" + ) as create_channel: + create_channel.return_value = new_channel + with mock.patch.object( + self._get_target_class(), "_start_background_channel_refresh" + ): + client = self._make_client( + project="project-id", use_emulator=False + ) + create_channel.reset_mock() + try: + client._manage_channel( + channel_idx, + refresh_interval_min=expected_refresh, + refresh_interval_max=expected_refresh, + grace_period=expected_grace, + ) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + 1 + assert create_channel.call_count == num_cycles + assert replace_channel.call_count == num_cycles + for call in replace_channel.call_args_list: + (args, kwargs) = call + assert args[0] == channel_idx + assert kwargs["grace"] == expected_grace + assert kwargs["new_channel"] == new_channel + client.close() + + def test__register_instance(self): + """test instance registration""" + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: f"prefix/{b}" + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_tasks = [] + client_mock._start_background_channel_refresh.side_effect = ( + lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + ) + mock_channels = [mock.Mock() for i in range(5)] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = mock.Mock() + table_mock = mock.Mock() + self._get_target_class()._register_instance( + client_mock, "instance-1", table_mock + ) + assert client_mock._start_background_channel_refresh.call_count == 1 + expected_key = ( + "prefix/instance-1", + table_mock.table_name, + table_mock.app_profile_id, + ) + assert len(active_instances) == 1 + assert expected_key == tuple(list(active_instances)[0]) + assert len(instance_owners) == 1 + assert expected_key == tuple(list(instance_owners)[0]) + assert client_mock._channel_refresh_tasks + table_mock2 = mock.Mock() + self._get_target_class()._register_instance( + client_mock, "instance-2", table_mock2 + ) + assert client_mock._start_background_channel_refresh.call_count == 1 + assert client_mock._ping_and_warm_instances.call_count == len(mock_channels) + for channel in mock_channels: + assert channel in [ + call[0][0] + for call in client_mock._ping_and_warm_instances.call_args_list + ] + assert len(active_instances) == 2 + assert len(instance_owners) == 2 + expected_key2 = ( + "prefix/instance-2", + table_mock2.table_name, + table_mock2.app_profile_id, + ) + assert any( + [ + expected_key2 == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) + assert any( + [ + expected_key2 == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) @pytest.mark.parametrize( - "num_starting,num_new_errors,expected_total_errors", + "insert_instances,expected_active,expected_owner_keys", [ - (0, 0, 0), - (0, 1, 1), - (0, 2, 2), - (1, 0, 1), - (1, 1, 2), - (10, 2, 12), - (10, 20, 20), + ([("i", "t", None)], [("i", "t", None)], [("i", "t", None)]), + ([("i", "t", "p")], [("i", "t", "p")], [("i", "t", "p")]), + ([("1", "t", "p"), ("1", "t", "p")], 
[("1", "t", "p")], [("1", "t", "p")]), + ( + [("1", "t", "p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + ), ], ) - def test__flush_internal_with_errors( - self, num_starting, num_new_errors, expected_total_errors + def test__register_instance_state( + self, insert_instances, expected_active, expected_owner_keys ): - """errors returned from _execute_mutate_rows should be added to internal exceptions""" - from google.cloud.bigtable.data import exceptions - - num_entries = 10 - expected_errors = [ - exceptions.FailedMutationEntryError(mock.Mock(), mock.Mock(), ValueError()) - ] * num_new_errors - with self._make_one() as instance: - instance._oldest_exceptions = [mock.Mock()] * num_starting - with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: - execute_mock.return_value = expected_errors - with mock.patch.object( - instance._flow_control, "add_to_flow" - ) as flow_mock: - - def gen(x): - yield x - - flow_mock.side_effect = lambda x: gen(x) - mutations = [self._make_mutation(count=1, size=1)] * num_entries - instance._flush_internal(mutations) - assert instance._entries_processed_since_last_raise == num_entries - assert execute_mock.call_count == 1 - assert flow_mock.call_count == 1 - found_exceptions = instance._oldest_exceptions + list( - instance._newest_exceptions - ) - assert len(found_exceptions) == expected_total_errors - for i in range(num_starting, expected_total_errors): - assert found_exceptions[i] == expected_errors[i - num_starting] - assert found_exceptions[i].index is None - instance._oldest_exceptions.clear() - instance._newest_exceptions.clear() - - def _mock_gapic_return(self, num=5): - from google.cloud.bigtable_v2.types import MutateRowsResponse - from google.rpc import status_pb2 - - def gen(num): - for i in range(num): - entry = MutateRowsResponse.Entry( - index=i, status=status_pb2.Status(code=0) + """test that active_instances and instance_owners are updated as expected""" + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: b + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_tasks = [] + client_mock._start_background_channel_refresh.side_effect = ( + lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + ) + mock_channels = [mock.Mock() for i in range(5)] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = mock.Mock() + table_mock = mock.Mock() + for instance, table, profile in insert_instances: + table_mock.table_name = table + table_mock.app_profile_id = profile + self._get_target_class()._register_instance( + client_mock, instance, table_mock + ) + assert len(active_instances) == len(expected_active) + assert len(instance_owners) == len(expected_owner_keys) + for expected in expected_active: + assert any( + [ + expected == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) + for expected in expected_owner_keys: + assert any( + [ + expected == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) + + def test__remove_instance_registration(self): + client = self._make_client(project="project-id") + table = mock.Mock() + client._register_instance("instance-1", table) + client._register_instance("instance-2", table) + assert len(client._active_instances) == 2 + assert len(client._instance_owners.keys()) == 2 + 
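# --- Editor's illustrative sketch; not part of the generated test. ---
# A toy model of the ownership bookkeeping this test exercises: instance keys map to
# the set of owning table ids, and a key stays in the active set only while at least
# one owner remains.  Names are hypothetical, not the library's implementation.
active_instances: set = set()
instance_owners: dict = {}

def register(key, owner):
    instance_owners.setdefault(key, set()).add(id(owner))
    active_instances.add(key)

def remove_registration(key, owner):
    if key not in instance_owners:
        return False  # unknown key: mirrors the "fake-key" case asserting not success
    instance_owners[key].discard(id(owner))
    if not instance_owners[key]:
        active_instances.discard(key)
    return True

owner_table = object()
register("instance-1", owner_table)
register("instance-2", owner_table)
assert remove_registration("instance-1", owner_table) and "instance-1" not in active_instances
assert not remove_registration("fake-key", owner_table) and len(active_instances) == 1
# --- end sketch ---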
instance_1_path = client._gapic_client.instance_path( + client.project, "instance-1" + ) + instance_1_key = (instance_1_path, table.table_name, table.app_profile_id) + instance_2_path = client._gapic_client.instance_path( + client.project, "instance-2" + ) + instance_2_key = (instance_2_path, table.table_name, table.app_profile_id) + assert len(client._instance_owners[instance_1_key]) == 1 + assert list(client._instance_owners[instance_1_key])[0] == id(table) + assert len(client._instance_owners[instance_2_key]) == 1 + assert list(client._instance_owners[instance_2_key])[0] == id(table) + success = client._remove_instance_registration("instance-1", table) + assert success + assert len(client._active_instances) == 1 + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 1 + assert client._active_instances == {instance_2_key} + success = client._remove_instance_registration("fake-key", table) + assert not success + assert len(client._active_instances) == 1 + client.close() + + def test__multiple_table_registration(self): + """ + registering with multiple tables with the same key should + add multiple owners to instance_owners, but only keep one copy + of shared key in active_instances + """ + from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + + with self._make_client(project="project-id") as client: + with client.get_table("instance_1", "table_1") as table_1: + instance_1_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id ) - yield MutateRowsResponse(entries=[entry]) + assert len(client._instance_owners[instance_1_key]) == 1 + assert len(client._active_instances) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + with client.get_table("instance_1", "table_1") as table_2: + assert len(client._instance_owners[instance_1_key]) == 2 + assert len(client._active_instances) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + with client.get_table("instance_1", "table_3") as table_3: + instance_3_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_3_key = _WarmedInstanceKey( + instance_3_path, table_3.table_name, table_3.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 2 + assert len(client._instance_owners[instance_3_key]) == 1 + assert len(client._active_instances) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + assert id(table_3) in client._instance_owners[instance_3_key] + assert len(client._active_instances) == 1 + assert instance_1_key in client._active_instances + assert id(table_2) not in client._instance_owners[instance_1_key] + assert len(client._active_instances) == 0 + assert instance_1_key not in client._active_instances + assert len(client._instance_owners[instance_1_key]) == 0 + + def test__multiple_instance_registration(self): + """ + registering with multiple instance keys should update the key + in instance_owners and active_instances + """ + from google.cloud.bigtable.data._async.client import _WarmedInstanceKey - return gen(num) + with self._make_client(project="project-id") as client: + with client.get_table("instance_1", "table_1") as table_1: + with client.get_table("instance_2", "table_2") as table_2: + instance_1_path 
= client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id + ) + instance_2_path = client._gapic_client.instance_path( + client.project, "instance_2" + ) + instance_2_key = _WarmedInstanceKey( + instance_2_path, table_2.table_name, table_2.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 1 + assert len(client._instance_owners[instance_2_key]) == 1 + assert len(client._active_instances) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_2_key] + assert len(client._active_instances) == 1 + assert instance_1_key in client._active_instances + assert len(client._instance_owners[instance_2_key]) == 0 + assert len(client._instance_owners[instance_1_key]) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + assert len(client._active_instances) == 0 + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 0 + + def test_get_table(self): + from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + + client = self._make_client(project="project-id") + assert not client._active_instances + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + table = client.get_table( + expected_instance_id, expected_table_id, expected_app_profile_id + ) + time.sleep(0) + assert isinstance(table, Table) + assert table.table_id == expected_table_id + assert ( + table.table_name + == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert table.instance_id == expected_instance_id + assert ( + table.instance_name + == f"projects/{client.project}/instances/{expected_instance_id}" + ) + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + client.close() - def test_timer_flush_end_to_end(self): - """Flush should automatically trigger after flush_interval""" - num_nutations = 10 - mutations = [self._make_mutation(count=2, size=2)] * num_nutations - with self._make_one(flush_interval=0.05) as instance: - instance._table.default_operation_timeout = 10 - instance._table.default_attempt_timeout = 9 + def test_get_table_arg_passthrough(self): + """All arguments passed in get_table should be sent to constructor""" + with self._make_client(project="project-id") as client: with mock.patch.object( - instance._table.client._gapic_client, "mutate_rows" - ) as gapic_mock: - gapic_mock.side_effect = ( - lambda *args, **kwargs: self._mock_gapic_return(num_nutations) + TestTable._get_target_class(), "__init__" + ) as mock_constructor: + mock_constructor.return_value = None + assert not client._active_instances + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_args = (1, "test", {"test": 2}) + expected_kwargs = {"hello": "world", "test": 2} + client.get_table( + expected_instance_id, + expected_table_id, + expected_app_profile_id, + *expected_args, + **expected_kwargs, + ) + mock_constructor.assert_called_once_with( + client, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + 
*expected_args, + **expected_kwargs, ) - for m in mutations: - instance.append(m) - assert instance._entries_processed_since_last_raise == 0 - time.sleep(0.1) - assert instance._entries_processed_since_last_raise == num_nutations - - def test__execute_mutate_rows(self): - if self.is_async(): - mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" - else: - mutate_path = "_sync._mutate_rows._MutateRowsOperation" - with mock.patch(f"google.cloud.bigtable.data.{mutate_path}") as mutate_rows: - mutate_rows.return_value = mock.Mock() - start_operation = mutate_rows().start - table = mock.Mock() - table.table_name = "test-table" - table.app_profile_id = "test-app-profile" - table.default_mutate_rows_operation_timeout = 17 - table.default_mutate_rows_attempt_timeout = 13 - table.default_mutate_rows_retryable_errors = () - with self._make_one(table) as instance: - batch = [self._make_mutation()] - result = instance._execute_mutate_rows(batch) - assert start_operation.call_count == 1 - (args, kwargs) = mutate_rows.call_args - assert args[0] == table.client._gapic_client - assert args[1] == table - assert args[2] == batch - kwargs["operation_timeout"] == 17 - kwargs["attempt_timeout"] == 13 - assert result == [] - - def test__execute_mutate_rows_returns_errors(self): - """Errors from operation should be retruned as list""" - from google.cloud.bigtable.data.exceptions import ( - MutationsExceptionGroup, - FailedMutationEntryError, - ) - if self.is_async(): - mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" - else: - mutate_path = "_sync._mutate_rows._MutateRowsOperation" - with mock.patch( - f"google.cloud.bigtable.data.{mutate_path}.start" - ) as mutate_rows: - err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) - err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) - mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 17 - table.default_mutate_rows_attempt_timeout = 13 - table.default_mutate_rows_retryable_errors = () - with self._make_one(table) as instance: - batch = [self._make_mutation()] - result = instance._execute_mutate_rows(batch) - assert len(result) == 2 - assert result[0] == err1 - assert result[1] == err2 - assert result[0].index is None - assert result[1].index is None - - def test__raise_exceptions(self): - """Raise exceptions and reset error state""" - from google.cloud.bigtable.data import exceptions - - expected_total = 1201 - expected_exceptions = [RuntimeError("mock")] * 3 - with self._make_one() as instance: - instance._oldest_exceptions = expected_exceptions - instance._entries_processed_since_last_raise = expected_total - try: - instance._raise_exceptions() - except exceptions.MutationsExceptionGroup as exc: - assert list(exc.exceptions) == expected_exceptions - assert str(expected_total) in str(exc) - assert instance._entries_processed_since_last_raise == 0 - (instance._oldest_exceptions, instance._newest_exceptions) = ([], []) - instance._raise_exceptions() - - def test___aenter__(self): - """Should return self""" - with self._make_one() as instance: - assert instance.__enter__() == instance - - def test___aexit__(self): - """aexit should call close""" - with self._make_one() as instance: - with mock.patch.object(instance, "close") as close_mock: - instance.__exit__(None, None, None) - assert close_mock.call_count == 1 + def test_get_table_context_manager(self): + from google.cloud.bigtable.data._async.client 
import _WarmedInstanceKey + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_project_id = "project-id" + with mock.patch.object(Table, "close") as close_mock: + with self._make_client(project=expected_project_id) as client: + with client.get_table( + expected_instance_id, expected_table_id, expected_app_profile_id + ) as table: + time.sleep(0) + assert isinstance(table, Table) + assert table.table_id == expected_table_id + assert ( + table.table_name + == f"projects/{expected_project_id}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert table.instance_id == expected_instance_id + assert ( + table.instance_name + == f"projects/{expected_project_id}/instances/{expected_instance_id}" + ) + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + assert close_mock.call_count == 1 + + def test_multiple_pool_sizes(self): + pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] + for pool_size in pool_sizes: + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client._channel_refresh_tasks) == pool_size + client_duplicate = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client_duplicate._channel_refresh_tasks) == pool_size + assert str(pool_size) in str(client.transport) + client.close() + client_duplicate.close() def test_close(self): - """Should clean up all resources""" - with self._make_one() as instance: - with mock.patch.object(instance, "_schedule_flush") as flush_mock: - with mock.patch.object(instance, "_raise_exceptions") as raise_mock: - instance.close() - assert instance.closed is True - assert instance._flush_timer.done() is True - assert instance._flush_jobs == set() - assert flush_mock.call_count == 1 - assert raise_mock.call_count == 1 - - def test_close_w_exceptions(self): - """Raise exceptions on close""" - from google.cloud.bigtable.data import exceptions - - expected_total = 10 - expected_exceptions = [RuntimeError("mock")] - with self._make_one() as instance: - instance._oldest_exceptions = expected_exceptions - instance._entries_processed_since_last_raise = expected_total - try: - instance.close() - except exceptions.MutationsExceptionGroup as exc: - assert list(exc.exceptions) == expected_exceptions - assert str(expected_total) in str(exc) - assert instance._entries_processed_since_last_raise == 0 - (instance._oldest_exceptions, instance._newest_exceptions) = ([], []) - - def test__on_exit(self, recwarn): - """Should raise warnings if unflushed mutations exist""" - with self._make_one() as instance: - instance._on_exit() - assert len(recwarn) == 0 - num_left = 4 - instance._staged_entries = [mock.Mock()] * num_left - with pytest.warns(UserWarning) as w: - instance._on_exit() - assert len(w) == 1 - assert "unflushed mutations" in str(w[0].message).lower() - assert str(num_left) in str(w[0].message) - instance._closed.set() - instance._on_exit() - assert len(recwarn) == 0 - instance._staged_entries = [] - - def test_atexit_registration(self): - """Should run _on_exit on program termination""" - import atexit - - with mock.patch.object(atexit, "register") as register_mock: - assert register_mock.call_count == 0 - with 
self._make_one(): - assert register_mock.call_count == 1 - - def test_timeout_args_passed(self): - """ - batch_operation_timeout and batch_attempt_timeout should be used - in api calls - """ - if self.is_async(): - mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" - else: - mutate_path = "_sync._mutate_rows._MutateRowsOperation" - with mock.patch( - f"google.cloud.bigtable.data.{mutate_path}", return_value=mock.Mock() - ) as mutate_rows: - expected_operation_timeout = 17 - expected_attempt_timeout = 13 - with self._make_one( - batch_operation_timeout=expected_operation_timeout, - batch_attempt_timeout=expected_attempt_timeout, - ) as instance: - assert instance._operation_timeout == expected_operation_timeout - assert instance._attempt_timeout == expected_attempt_timeout - instance._execute_mutate_rows([self._make_mutation()]) - assert mutate_rows.call_count == 1 - kwargs = mutate_rows.call_args[1] - assert kwargs["operation_timeout"] == expected_operation_timeout - assert kwargs["attempt_timeout"] == expected_attempt_timeout + pool_size = 7 + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client._channel_refresh_tasks) == pool_size + tasks_list = list(client._channel_refresh_tasks) + for task in client._channel_refresh_tasks: + assert not task.done() + with mock.patch.object( + PooledBigtableGrpcTransport, "close", mock.Mock() + ) as close_mock: + client.close() + close_mock.assert_called_once() + close_mock.assert_called_once() + for task in tasks_list: + assert task.done() + assert client._channel_refresh_tasks == [] + + def test_context_manager(self): + close_mock = mock.Mock() + true_close = None + with self._make_client(project="project-id") as client: + true_close = client.close() + client.close = close_mock + for task in client._channel_refresh_tasks: + assert not task.done() + assert client.project == "project-id" + assert client._active_instances == set() + close_mock.assert_not_called() + close_mock.assert_called_once() + close_mock.assert_called_once() + true_close + + +class TestTable(ABC): + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.data._sync.client import Table + + return Table + + def test_table_ctor(self): + from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_operation_timeout = 123 + expected_attempt_timeout = 12 + expected_read_rows_operation_timeout = 1.5 + expected_read_rows_attempt_timeout = 0.5 + expected_mutate_rows_operation_timeout = 2.5 + expected_mutate_rows_attempt_timeout = 0.75 + client = self._make_client() + assert not client._active_instances + table = self._get_target_class()( + client, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + default_operation_timeout=expected_operation_timeout, + default_attempt_timeout=expected_attempt_timeout, + default_read_rows_operation_timeout=expected_read_rows_operation_timeout, + default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout, + default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, + default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, + ) + time.sleep(0) + assert table.table_id == expected_table_id + assert table.instance_id == expected_instance_id + assert 
table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + assert table.default_operation_timeout == expected_operation_timeout + assert table.default_attempt_timeout == expected_attempt_timeout + assert ( + table.default_read_rows_operation_timeout + == expected_read_rows_operation_timeout + ) + assert ( + table.default_read_rows_attempt_timeout + == expected_read_rows_attempt_timeout + ) + assert ( + table.default_mutate_rows_operation_timeout + == expected_mutate_rows_operation_timeout + ) + assert ( + table.default_mutate_rows_attempt_timeout + == expected_mutate_rows_attempt_timeout + ) + table._register_instance_task + assert table._register_instance_task.done() + assert not table._register_instance_task.cancelled() + assert table._register_instance_task.exception() is None + client.close() + + def test_table_ctor_defaults(self): + """should provide default timeout values and app_profile_id""" + expected_table_id = "table-id" + expected_instance_id = "instance-id" + client = self._make_client() + assert not client._active_instances + table = Table(client, expected_instance_id, expected_table_id) + time.sleep(0) + assert table.table_id == expected_table_id + assert table.instance_id == expected_instance_id + assert table.app_profile_id is None + assert table.client is client + assert table.default_operation_timeout == 60 + assert table.default_read_rows_operation_timeout == 600 + assert table.default_mutate_rows_operation_timeout == 600 + assert table.default_attempt_timeout == 20 + assert table.default_read_rows_attempt_timeout == 20 + assert table.default_mutate_rows_attempt_timeout == 60 + client.close() + + def test_table_ctor_invalid_timeout_values(self): + """bad timeout values should raise ValueError""" + client = self._make_client() + timeout_pairs = [ + ("default_operation_timeout", "default_attempt_timeout"), + ( + "default_read_rows_operation_timeout", + "default_read_rows_attempt_timeout", + ), + ( + "default_mutate_rows_operation_timeout", + "default_mutate_rows_attempt_timeout", + ), + ] + for operation_timeout, attempt_timeout in timeout_pairs: + with pytest.raises(ValueError) as e: + Table(client, "", "", **{attempt_timeout: -1}) + assert "attempt_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + Table(client, "", "", **{operation_timeout: -1}) + assert "operation_timeout must be greater than 0" in str(e.value) + client.close() + + def test_table_ctor_sync(self): + client = mock.Mock() + with pytest.raises(RuntimeError) as e: + Table(client, "instance-id", "table-id") + assert e.match("TableAsync must be created within an async event loop context.") @pytest.mark.parametrize( - "limit,in_e,start_e,end_e", + "fn_name,fn_args,retry_fn_path,extra_retryables", [ - (10, 0, (10, 0), (10, 0)), - (1, 10, (0, 0), (1, 1)), - (10, 1, (0, 0), (1, 0)), - (10, 10, (0, 0), (10, 0)), - (10, 11, (0, 0), (10, 1)), - (3, 20, (0, 0), (3, 3)), - (10, 20, (0, 0), (10, 10)), - (10, 21, (0, 0), (10, 10)), - (2, 1, (2, 0), (2, 1)), - (2, 1, (1, 0), (2, 0)), - (2, 2, (1, 0), (2, 1)), - (3, 1, (3, 1), (3, 2)), - (3, 3, (3, 1), (3, 3)), - (1000, 5, (999, 0), (1000, 4)), - (1000, 5, (0, 0), (5, 0)), - (1000, 5, (1000, 0), (1000, 5)), + ( + "read_rows_stream", + (ReadRowsQuery(),), + 
"google.api_core.retry.retry_target_stream_async", + (), + ), + ( + "read_rows", + (ReadRowsQuery(),), + "google.api_core.retry.retry_target_stream_async", + (), + ), + ( + "read_row", + (b"row_key",), + "google.api_core.retry.retry_target_stream_async", + (), + ), + ( + "read_rows_sharded", + ([ReadRowsQuery()],), + "google.api_core.retry.retry_target_stream_async", + (), + ), + ( + "row_exists", + (b"row_key",), + "google.api_core.retry.retry_target_stream_async", + (), + ), + ("sample_row_keys", (), "google.api_core.retry.retry_target_async", ()), + ( + "mutate_row", + (b"row_key", [mock.Mock()]), + "google.api_core.retry.retry_target_async", + (), + ), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mock.Mock()])],), + "google.api_core.retry.retry_target_async", + (_MutateRowsIncomplete,), + ), ], ) - def test__add_exceptions(self, limit, in_e, start_e, end_e): - """ - Test that the _add_exceptions function properly updates the - _oldest_exceptions and _newest_exceptions lists - Args: - - limit: the _exception_list_limit representing the max size of either list - - in_e: size of list of exceptions to send to _add_exceptions - - start_e: a tuple of ints representing the initial sizes of _oldest_exceptions and _newest_exceptions - - end_e: a tuple of ints representing the expected sizes of _oldest_exceptions and _newest_exceptions - """ - from collections import deque - - input_list = [RuntimeError(f"mock {i}") for i in range(in_e)] - mock_batcher = mock.Mock() - mock_batcher._oldest_exceptions = [ - RuntimeError(f"starting mock {i}") for i in range(start_e[0]) - ] - mock_batcher._newest_exceptions = deque( - [RuntimeError(f"starting mock {i}") for i in range(start_e[1])], - maxlen=limit, - ) - mock_batcher._exception_list_limit = limit - mock_batcher._exceptions_since_last_raise = 0 - self._get_target_class()._add_exceptions(mock_batcher, input_list) - assert len(mock_batcher._oldest_exceptions) == end_e[0] - assert len(mock_batcher._newest_exceptions) == end_e[1] - assert mock_batcher._exceptions_since_last_raise == in_e - oldest_list_diff = end_e[0] - start_e[0] - newest_list_diff = min(max(in_e - oldest_list_diff, 0), limit) - for i in range(oldest_list_diff): - assert mock_batcher._oldest_exceptions[i + start_e[0]] == input_list[i] - for i in range(1, newest_list_diff + 1): - assert mock_batcher._newest_exceptions[-i] == input_list[-i] - @pytest.mark.parametrize( "input_retryables,expected_retryables", [ @@ -856,34 +1099,85 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): ([4], [core_exceptions.DeadlineExceeded]), ], ) - def test_customizable_retryable_errors(self, input_retryables, expected_retryables): + def test_customizable_retryable_errors( + self, + input_retryables, + expected_retryables, + fn_name, + fn_args, + retry_fn_path, + extra_retryables, + ): """ Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. 
""" - retryn_fn = ( - "retry_target_async" - if "Async" in self._get_target_class().__name__ - else "retry_target" - ) - with mock.patch.object( - google.api_core.retry, "if_exception_type" - ) as predicate_builder_mock: - with mock.patch.object(google.api_core.retry, retryn_fn) as retry_fn_mock: - table = None - with mock.patch("asyncio.create_task"): - table = Table(mock.Mock(), "instance", "table") - with self._make_one( - table, batch_retryable_errors=input_retryables - ) as instance: - assert instance._retryable_errors == expected_retryables - expected_predicate = lambda a: a in expected_retryables + with mock.patch(retry_fn_path) as retry_fn_mock: + with self._make_client() as client: + table = client.get_table("instance-id", "table-id") + expected_predicate = lambda a: a in expected_retryables + retry_fn_mock.side_effect = RuntimeError("stop early") + with mock.patch( + "google.api_core.retry.if_exception_type" + ) as predicate_builder_mock: predicate_builder_mock.return_value = expected_predicate - retry_fn_mock.side_effect = RuntimeError("stop early") - mutation = self._make_mutation(count=1, size=1) - instance._execute_mutate_rows([mutation]) + with pytest.raises(Exception): + test_fn = table.__getattribute__(fn_name) + test_fn(*fn_args, retryable_errors=input_retryables) predicate_builder_mock.assert_called_once_with( - *expected_retryables, _MutateRowsIncomplete + *expected_retryables, *extra_retryables ) retry_call_args = retry_fn_mock.call_args_list[0].args assert retry_call_args[1] is expected_predicate + + @pytest.mark.parametrize( + "fn_name,fn_args,gapic_fn", + [ + ("read_rows_stream", (ReadRowsQuery(),), "read_rows"), + ("read_rows", (ReadRowsQuery(),), "read_rows"), + ("read_row", (b"row_key",), "read_rows"), + ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), + ("row_exists", (b"row_key",), "read_rows"), + ("sample_row_keys", (), "sample_row_keys"), + ("mutate_row", (b"row_key", [mock.Mock()]), "mutate_row"), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + "mutate_rows", + ), + ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), + ( + "read_modify_write_row", + (b"row_key", mock.Mock()), + "read_modify_write_row", + ), + ], + ) + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): + """check that all requests attach proper metadata headers""" + profile = "profile" if include_app_profile else None + with mock.patch.object( + BigtableClient, gapic_fn, mock.mock.Mock() + ) as gapic_mock: + gapic_mock.side_effect = RuntimeError("stop early") + with self._make_client() as client: + table = Table(client, "instance-id", "table-id", profile) + try: + test_fn = table.__getattribute__(fn_name) + maybe_stream = test_fn(*fn_args) + [i for i in maybe_stream] + except Exception: + pass + kwargs = gapic_mock.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == "x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata From ef892bc5a3d2d8d90b25fd927d9b12fa8ad8f622 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Apr 2024 13:34:24 -0700 Subject: [PATCH 040/360] moved more logic into 
generated portion --- google/cloud/bigtable/data/_async/client.py | 77 ++++++------ google/cloud/bigtable/data/_sync/_autogen.py | 75 +++++++++--- google/cloud/bigtable/data/_sync/client.py | 112 +----------------- .../cloud/bigtable/data/_sync/sync_gen.yaml | 14 ++- .../transports/pooled_grpc_asyncio.py | 16 ++- sync_surface_generator.py | 1 - tests/unit/data/_async/test_client.py | 11 +- tests/unit/data/_sync/test_autogen.py | 18 ++- 8 files changed, 145 insertions(+), 179 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index b7090cc59..a2baea2c6 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -33,15 +33,15 @@ import os from functools import partial -from grpc.aio import Channel as AsyncChannel +from grpc import Channel from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( PooledBigtableGrpcAsyncIOTransport, - PooledChannel, ) +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import PooledChannel as AsyncPooledChannel from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest from google.cloud.client import ClientWithProject from google.cloud.environment_vars import BIGTABLE_EMULATOR # type: ignore @@ -123,7 +123,10 @@ def __init__( - ValueError if pool_size is less than 1 """ # set up transport in registry - transport_str = self._transport_init(pool_size) + # TODO: simplify when released: https://github.com/googleapis/gapic-generator-python/pull/1699 + transport_str = f"bt-{self._client_version()}-{pool_size}" + transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) + BigtableClientMeta._transport_registry[transport_str] = transport # set up client info headers for veneer library client_info = DEFAULT_CLIENT_INFO client_info.client_library_version = self._client_version() @@ -153,6 +156,7 @@ def __init__( client_options=client_options, client_info=client_info, ) + self._is_closed = asyncio.Event() self.transport = cast( PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport ) @@ -170,7 +174,11 @@ def __init__( RuntimeWarning, stacklevel=2, ) - self._prep_emulator_channel(self._emulator_host, pool_size) + self.transport._grpc_channel = AsyncPooledChannel( + pool_size=pool_size, + host=self._emulator_host, + insecure=True, + ) # refresh cached stubs to use emulator pool self.transport._stubs = {} self.transport._prep_wrapped_messages(client_info) @@ -186,29 +194,6 @@ def __init__( stacklevel=2, ) - def _transport_init(self, pool_size: int) -> str: - """ - Helper function for intiializing the transport object - - Different implementations for sync vs async client - """ - transport_str = f"pooled_grpc_asyncio_{pool_size}" - transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) - BigtableClientMeta._transport_registry[transport_str] = transport - return transport_str - - def _prep_emulator_channel(self, host:str, pool_size:int): - """ - Helper function for initializing emulator's insecure grpc channel - - Different implementations for sync vs async client - """ - self.transport._grpc_channel = PooledChannel( - pool_size=pool_size, - host=host, - insecure=True, - ) - @staticmethod def 
_client_version() -> str: """ @@ -222,7 +207,7 @@ def _start_background_channel_refresh(self) -> None: Raises: - RuntimeError if not called in an asyncio event loop """ - if not self._channel_refresh_tasks and not self._emulator_host: + if not self._channel_refresh_tasks and not self._emulator_host and not self._is_closed.is_set(): # raise RuntimeError if there is no event loop asyncio.get_running_loop() for channel_idx in range(self.transport.pool_size): @@ -238,6 +223,7 @@ async def close(self, timeout: float = 2.0): """ Cancel all background tasks """ + self._is_closed.set() for task in self._channel_refresh_tasks: task.cancel() group = asyncio.gather(*self._channel_refresh_tasks, return_exceptions=True) @@ -246,7 +232,7 @@ async def close(self, timeout: float = 2.0): self._channel_refresh_tasks = [] async def _ping_and_warm_instances( - self, channel: AsyncChannel, instance_key: _WarmedInstanceKey | None = None + self, channel: Channel, instance_key: _WarmedInstanceKey | None = None ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel @@ -267,8 +253,9 @@ async def _ping_and_warm_instances( request_serializer=PingAndWarmRequest.serialize, ) # prepare list of coroutines to run - tasks = [ - ping_rpc( + partial_list = [ + partial( + ping_rpc, request={"name": instance_name, "app_profile_id": app_profile_id}, metadata=[ ( @@ -280,8 +267,22 @@ async def _ping_and_warm_instances( ) for (instance_name, table_name, app_profile_id) in instance_list ] - # execute coroutines in parallel - result_list = await asyncio.gather(*tasks, return_exceptions=True) + return await self._execute_ping_and_warms(*partial_list) + + async def _execute_ping_and_warms(self, *fns) -> list[BaseException | None]: + """ + Execute batch of ping and warm requests in parallel + + Will have separate implementation for sync and async clients + + Args: + - fns: list of partial functions to execute ping and warm requests + Returns: + - list of results or exceptions from the ping requests + """ + # extract coroutine out of partials + coro_list = [fn() for fn in fns] + result_list = await asyncio.gather(*coro_list, return_exceptions=True) # return None in place of empty successful responses return [r or None for r in result_list] @@ -320,19 +321,21 @@ async def _manage_channel( channel = self.transport.channels[channel_idx] await self._ping_and_warm_instances(channel) # continuously refresh the channel every `refresh_interval` seconds - while True: + while not self._is_closed.is_set(): await asyncio.sleep(next_sleep) + if self._is_closed.is_set(): + break # prepare new channel for use new_channel = self.transport.grpc_channel._create_channel() await self._ping_and_warm_instances(new_channel) # cycle channel out of use, with long grace window before closure - start_timestamp = time.time() + start_timestamp = time.monotonic() await self.transport.replace_channel( - channel_idx, grace=grace_period, new_channel=new_channel + channel_idx, grace=grace_period, new_channel=new_channel, event=self._is_closed ) # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) - next_sleep = next_refresh - (time.time() - start_timestamp) + next_sleep = next_refresh - (time.monotonic() - start_timestamp) async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: """ diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 9de2ef21f..fdd9f880a 100644 --- 
a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -30,6 +30,7 @@ import concurrent.futures import functools import os +import random import threading import time import warnings @@ -76,13 +77,16 @@ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient -from google.cloud.bigtable_v2.services.bigtable.transports.grpc import ( - BigtableGrpcTransport, +from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledBigtableGrpcTransport, + PooledChannel, ) from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB from google.cloud.bigtable_v2.types import RowRange as RowRangePB from google.cloud.bigtable_v2.types import RowSet as RowSetPB +from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest from google.cloud.client import ClientWithProject from google.cloud.environment_vars import BIGTABLE_EMULATOR import google.auth._default @@ -897,7 +901,9 @@ def __init__( - RuntimeError if called outside of an async context (no running event loop) - ValueError if pool_size is less than 1 """ - transport_str = self._transport_init(pool_size) + transport_str = f"bt-{self._client_version()}-{pool_size}" + transport = PooledBigtableGrpcTransport.with_fixed_size(pool_size) + BigtableClientMeta._transport_registry[transport_str] = transport client_info = DEFAULT_CLIENT_INFO client_info.client_library_version = self._client_version() if type(client_options) is dict: @@ -923,7 +929,8 @@ def __init__( client_options=client_options, client_info=client_info, ) - self.transport = cast(BigtableGrpcTransport, self._gapic_client.transport) + self._is_closed = threading.Event() + self.transport = cast(PooledBigtableGrpcTransport, self._gapic_client.transport) self._active_instances: Set[_WarmedInstanceKey] = set() self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() @@ -934,7 +941,9 @@ def __init__( RuntimeWarning, stacklevel=2, ) - self._prep_emulator_channel(self._emulator_host, pool_size) + self.transport._grpc_channel = PooledChannel( + pool_size=pool_size, host=self._emulator_host, insecure=True + ) self.transport._stubs = {} self.transport._prep_wrapped_messages(client_info) else: @@ -947,12 +956,6 @@ def __init__( stacklevel=2, ) - def _transport_init(self, pool_size: int) -> str: - raise NotImplementedError("Function not implemented in sync class") - - def _prep_emulator_channel(self, host: str, pool_size: int): - raise NotImplementedError("Function not implemented in sync class") - @staticmethod def _client_version() -> str: raise NotImplementedError("Function not implemented in sync class") @@ -960,9 +963,6 @@ def _client_version() -> str: def _start_background_channel_refresh(self) -> None: raise NotImplementedError("Function not implemented in sync class") - def close(self, timeout: float = 2.0): - """Cancel all background tasks""" - def _ping_and_warm_instances( self, channel: Channel, instance_key: _WarmedInstanceKey | None = None ) -> list[BaseException | None]: @@ -977,6 +977,31 @@ def _ping_and_warm_instances( Returns: - sequence of results or exceptions from the ping requests """ + 
instance_list = ( + [instance_key] if instance_key is not None else self._active_instances + ) + ping_rpc = channel.unary_unary( + "/google.bigtable.v2.Bigtable/PingAndWarm", + request_serializer=PingAndWarmRequest.serialize, + ) + partial_list = [ + partial( + ping_rpc, + request={"name": instance_name, "app_profile_id": app_profile_id}, + metadata=[ + ( + "x-goog-request-params", + f"name={instance_name}&app_profile_id={app_profile_id}", + ) + ], + wait_for_ready=True, + ) + for (instance_name, table_name, app_profile_id) in instance_list + ] + return self._execute_ping_and_warms(*partial_list) + + def _execute_ping_and_warms(self, *fns) -> list[BaseException | None]: + raise NotImplementedError("Function not implemented in sync class") def _manage_channel( self, @@ -1004,6 +1029,28 @@ def _manage_channel( grace_period: time to allow previous channel to serve existing requests before closing, in seconds """ + first_refresh = self._channel_init_time + random.uniform( + refresh_interval_min, refresh_interval_max + ) + next_sleep = max(first_refresh - time.monotonic(), 0) + if next_sleep > 0: + channel = self.transport.channels[channel_idx] + self._ping_and_warm_instances(channel) + while not self._is_closed.is_set(): + self._is_closed.wait(next_sleep) + if self._is_closed.is_set(): + break + new_channel = self.transport.grpc_channel._create_channel() + self._ping_and_warm_instances(new_channel) + start_timestamp = time.monotonic() + self.transport.replace_channel( + channel_idx, + grace=grace_period, + new_channel=new_channel, + event=self._is_closed, + ) + next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) + next_sleep = next_refresh - (time.monotonic() - start_timestamp) def _register_instance( self, instance_id: str, owner: google.cloud.bigtable.data._sync.client.Table diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 06a9dbb6d..281a9f024 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -51,124 +51,19 @@ def _executor(self) -> concurrent.futures.ThreadPoolExecutor: self._executor_instance = concurrent.futures.ThreadPoolExecutor() return self._executor_instance - @property - def _is_closed(self) -> threading.Event: - if not hasattr(self, "_is_closed_instance"): - self._is_closed_instance = threading.Event() - return self._is_closed_instance - - def _transport_init(self, pool_size: int) -> str: - transport_str = f"pooled_grpc_{pool_size}" - transport = PooledBigtableGrpcTransport.with_fixed_size(pool_size) - BigtableClientMeta._transport_registry[transport_str] = transport - return transport_str - - def _prep_emulator_channel(self, host:str, pool_size: int) -> str: - self.transport._grpc_channel = PooledChannel( - pool_size=pool_size, - host=host, - insecure=True, - ) - @staticmethod def _client_version() -> str: return f"{google.cloud.bigtable.__version__}-data" def _start_background_channel_refresh(self) -> None: - if not self._channel_refresh_tasks and not self._emulator_host: + if not self._channel_refresh_tasks and not self._emulator_host and not self._is_closed.is_set(): for channel_idx in range(self.transport.pool_size): self._channel_refresh_tasks.append( self._executor.submit(self._manage_channel, channel_idx) ) - def _manage_channel( - self, - channel_idx: int, - refresh_interval_min: float = 60 * 35, - refresh_interval_max: float = 60 * 45, - grace_period: float = 60 * 10, - ) -> None: - """ - Background routine that periodically refreshes 
and warms a grpc channel - - The backend will automatically close channels after 60 minutes, so - `refresh_interval` + `grace_period` should be < 60 minutes - - Runs continuously until the client is closed - - Args: - channel_idx: index of the channel in the transport's channel pool - refresh_interval_min: minimum interval before initiating refresh - process in seconds. Actual interval will be a random value - between `refresh_interval_min` and `refresh_interval_max` - refresh_interval_max: maximum interval before initiating refresh - process in seconds. Actual interval will be a random value - between `refresh_interval_min` and `refresh_interval_max` - grace_period: time to allow previous channel to serve existing - requests before closing, in seconds - """ - first_refresh = self._channel_init_time + random.uniform( - refresh_interval_min, refresh_interval_max - ) - next_sleep = max(first_refresh - time.monotonic(), 0) - if next_sleep > 0: - # warm the current channel immediately - channel = self.transport.channels[channel_idx] - self._ping_and_warm_instances(channel) - # continuously refresh the channel every `refresh_interval` seconds - while not self._is_closed.is_set(): - # sleep until next refresh, or until client is closed - self._is_closed.wait(next_sleep) - if self._is_closed.is_set(): - break - # prepare new channel for use - new_channel = self.transport.grpc_channel._create_channel() - self._ping_and_warm_instances(new_channel) - # cycle channel out of use, with long grace window before closure - start_timestamp = time.monotonic() - self.transport.replace_channel( - channel_idx, grace=grace_period, new_channel=new_channel, event=self._is_closed - ) - # subtract the time spent waiting for the channel to be replaced - next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) - next_sleep = next_refresh - (time.monotonic() - start_timestamp) - - def _ping_and_warm_instances( - self, channel: grpc.Channel, instance_key: _WarmedInstanceKey | None = None - ) -> list[BaseException | None]: - """ - Prepares the backend for requests on a channel - - Pings each Bigtable instance registered in `_active_instances` on the client - - Args: - - channel: grpc channel to warm - - instance_key: if provided, only warm the instance associated with the key - Returns: - - sequence of results or exceptions from the ping requests - """ - instance_list = ( - [instance_key] if instance_key is not None else self._active_instances - ) - ping_rpc = channel.unary_unary( - "/google.bigtable.v2.Bigtable/PingAndWarm", - request_serializer=PingAndWarmRequest.serialize, - ) - # execute pings in parallel - futures_list = [] - for (instance_name, table_name, app_profile_id) in instance_list: - future = self._executor.submit( - ping_rpc, - request={"name": instance_name, "app_profile_id": app_profile_id}, - metadata=[ - ( - "x-goog-request-params", - f"name={instance_name}&app_profile_id={app_profile_id}", - ) - ], - wait_for_ready=True, - ) - futures_list.append(future) + def _execute_ping_and_warms(self, *fns) -> list[BaseException | None]: + futures_list = [self._executor.submit(f) for f in fns] results_list = [] for future in futures_list: try: @@ -189,7 +84,6 @@ def close(self) -> None: self._executor.shutdown(wait=False) self._channel_refresh_tasks = [] self.transport.close() - super().close() class Table(Table_SyncGen): diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index 8ef268068..d8072b30c 100644 --- 
a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -17,14 +17,12 @@ text_replacements: # Find and replace specific text patterns StopAsyncIteration: StopIteration Awaitable: None BigtableAsyncClient: BigtableClient - PooledBigtableGrpcAsyncIOTransport: BigtableGrpcTransport - AsyncChannel: Channel retry_target_async: retry_target retry_target_stream_async: retry_target_stream added_imports: - "from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient" - - "from google.cloud.bigtable_v2.services.bigtable.transports.grpc import BigtableGrpcTransport" + - "from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport, PooledChannel" - "from typing import Generator, Iterable, Iterator" - "from grpc import Channel" - "import google.cloud.bigtable.data.exceptions as bt_exceptions" @@ -48,8 +46,14 @@ classes: # Specify transformations for individual classes - path: google.cloud.bigtable.data._async.client.BigtableDataClientAsync autogen_sync_name: BigtableDataClient_SyncGen concrete_path: google.cloud.bigtable.data._sync.client.BigtableDataClient - pass_methods: ["close", "_ping_and_warm_instances", "_manage_channel"] - error_methods: ["_start_background_channel_refresh", "_client_version", "_prep_emulator_channel", "_transport_init"] + drop_methods: ["close"] + error_methods: ["_start_background_channel_refresh", "_client_version", "_execute_ping_and_warms"] + asyncio_replacements: + sleep: self._is_closed.wait + text_replacements: + PooledBigtableGrpcAsyncIOTransport: PooledBigtableGrpcTransport + AsyncChannel: Channel + AsyncPooledChannel: PooledChannel - path: google.cloud.bigtable.data._async.client.TableAsync autogen_sync_name: Table_SyncGen concrete_path: google.cloud.bigtable.data._sync.client.Table diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py index fa7ab4f59..864b4ecc2 100644 --- a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py +++ b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py @@ -150,7 +150,7 @@ async def wait_for_state_change(self, last_observed_state): raise NotImplementedError() async def replace_channel( - self, channel_idx, grace=1, new_channel=None + self, channel_idx, grace=1, new_channel=None, event=None ) -> aio.Channel: """ Replaces a channel in the pool with a fresh one. @@ -163,8 +163,11 @@ async def replace_channel( grace(Optional[float]): The time to wait for active RPCs to finish. If a grace period is not specified (by passing None for grace), all existing RPCs are cancelled immediately. + If event is set at close time, grace is ignored new_channel(grpc.aio.Channel): a new channel to insert into the pool at `channel_idx`. If `None`, a new channel will be created. + event(Optional[threading.Event]): an event to signal when the + replacement should be aborted. If set, grace is ignored. 
""" if channel_idx >= len(self._pool) or channel_idx < 0: raise ValueError( @@ -174,6 +177,8 @@ async def replace_channel( new_channel = self._create_channel() old_channel = self._pool[channel_idx] self._pool[channel_idx] = new_channel + if event is not None and not event.is_set(): + grace = None await old_channel.close(grace=grace) return new_channel @@ -397,7 +402,7 @@ def channels(self) -> List[grpc.Channel]: return self._grpc_channel._pool async def replace_channel( - self, channel_idx, grace=1, new_channel=None + self, channel_idx, grace=1, new_channel=None, event=None ) -> aio.Channel: """ Replaces a channel in the pool with a fresh one. @@ -408,13 +413,16 @@ async def replace_channel( Args: channel_idx(int): the channel index in the pool to replace grace(Optional[float]): The time to wait for active RPCs to - finished. If a grace period is not specified (by passing None for + finish. If a grace period is not specified (by passing None for grace), all existing RPCs are cancelled immediately. + If event is set at close time, grace is ignored new_channel(grpc.aio.Channel): a new channel to insert into the pool at `channel_idx`. If `None`, a new channel will be created. + event(Optional[threading.Event]): an event to signal when the + replacement should be aborted. If set, grace is ignored. """ return await self._grpc_channel.replace_channel( - channel_idx=channel_idx, grace=grace, new_channel=new_channel + channel_idx=channel_idx, grace=grace, new_channel=new_channel, event=event ) diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 59856488a..3d96d033b 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -246,7 +246,6 @@ def _create_error_node(node, error_msg): raise_node = ast.Raise(exc=exc_node, cause=None) node.body = [raise_node] - def get_imports(self, filename): """ Get the imports from a file, and do a find-and-replace against asyncio_replacements diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index c761dd186..4c2fae2f0 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -102,14 +102,15 @@ async def test_ctor(self): async def test_ctor_super_inits(self): from google.cloud.client import ClientWithProject from google.api_core import client_options as client_options_lib + from google.cloud.bigtable import __version__ as bigtable_version project = "project-id" pool_size = 11 credentials = AnonymousCredentials() client_options = {"api_endpoint": "foo.bar:1234"} options_parsed = client_options_lib.from_dict(client_options) - asyncio_portion = "_asyncio" if self.is_async else "" - transport_str = f"pooled_grpc{asyncio_portion}_{pool_size}" + asyncio_portion = "-async" if self.is_async else "" + transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: bigtable_client_init.return_value = None with mock.patch.object( @@ -320,6 +321,7 @@ async def test__ping_and_warm_instances(self): test ping and warm with mocked asyncio.gather """ client_mock = mock.Mock() + client_mock._execute_ping_and_warms = lambda *args: self._get_target_class()._execute_ping_and_warms(client_mock, *args) gather_tuple = (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") with mock.patch.object(*gather_tuple, AsyncMock()) as gather: if self.is_async: @@ -381,6 +383,7 @@ async def test__ping_and_warm_single_instance(self): should be able to call ping and warm with 
single instance """ client_mock = mock.Mock() + client_mock._execute_ping_and_warms = lambda *args: self._get_target_class()._execute_ping_and_warms(client_mock, *args) gather_tuple = (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") with mock.patch.object(*gather_tuple, AsyncMock()) as gather: gather.side_effect = lambda *args, **kwargs: [mock.Mock() for _ in args] @@ -460,9 +463,7 @@ async def test__manage_channel_ping_and_warm(self): import threading client_mock = mock.Mock() - if not self.is_async: - # make sure loop is entered - client_mock._is_closed.is_set.return_value = False + client_mock._is_closed.is_set.return_value = False client_mock._channel_init_time = time.monotonic() channel_list = [mock.Mock(), mock.Mock()] client_mock.transport.channels = channel_list diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py index 9c0791bc7..10e7e1ad4 100644 --- a/tests/unit/data/_sync/test_autogen.py +++ b/tests/unit/data/_sync/test_autogen.py @@ -91,14 +91,15 @@ def test_ctor(self): def test_ctor_super_inits(self): from google.cloud.client import ClientWithProject from google.api_core import client_options as client_options_lib + from google.cloud.bigtable import __version__ as bigtable_version project = "project-id" pool_size = 11 credentials = AnonymousCredentials() client_options = {"api_endpoint": "foo.bar:1234"} options_parsed = client_options_lib.from_dict(client_options) - asyncio_portion = "_asyncio" if self.is_async else "" - transport_str = f"pooled_grpc{asyncio_portion}_{pool_size}" + asyncio_portion = "-async" if self.is_async else "" + transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" with mock.patch.object(BigtableClient, "__init__") as bigtable_client_init: bigtable_client_init.return_value = None with mock.patch.object( @@ -275,6 +276,11 @@ def test__start_background_channel_refresh(self, pool_size): def test__ping_and_warm_instances(self): """test ping and warm with mocked asyncio.gather""" client_mock = mock.Mock() + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) gather_tuple = ( (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") ) @@ -327,6 +333,11 @@ def test__ping_and_warm_instances(self): def test__ping_and_warm_single_instance(self): """should be able to call ping and warm with single instance""" client_mock = mock.Mock() + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) gather_tuple = ( (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") ) @@ -392,8 +403,7 @@ def test__manage_channel_ping_and_warm(self): import threading client_mock = mock.Mock() - if not self.is_async: - client_mock._is_closed.is_set.return_value = False + client_mock._is_closed.is_set.return_value = False client_mock._channel_init_time = time.monotonic() channel_list = [mock.Mock(), mock.Mock()] client_mock.transport.channels = channel_list From fe572bc3b0034b9ee889ecff02f1cab368f8d84e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Apr 2024 14:59:29 -0700 Subject: [PATCH 041/360] got table tests passing --- google/cloud/bigtable/data/_async/client.py | 14 ++-- google/cloud/bigtable/data/_sync/_autogen.py | 20 ++--- google/cloud/bigtable/data/_sync/client.py | 18 ++++- .../cloud/bigtable/data/_sync/sync_gen.yaml | 2 +- .../cloud/bigtable/data/_sync/unit_tests.yaml | 2 + 
tests/unit/data/_async/test_client.py | 40 ++++++---- tests/unit/data/_sync/test_autogen.py | 77 ++++++------------- 7 files changed, 81 insertions(+), 92 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index a2baea2c6..062b33bc5 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -559,11 +559,11 @@ def __init__( default_mutate_rows_retryable_errors or () ) self.default_retryable_errors = default_retryable_errors or () - self._register_with_client() + self._register_instance_future: asyncio.Future[None] = self._register_with_client() - def _register_with_client(self): + def _register_with_client(self) -> asyncio.Future[None]: """ Calls the client's _register_instance method to warm the grpc channels for this instance @@ -571,7 +571,7 @@ def _register_with_client(self): """ # raises RuntimeError if called outside of an async context (no running event loop) try: - self._register_instance_task = asyncio.create_task( + return asyncio.create_task( self.client._register_instance(self.instance_id, self) ) except RuntimeError as e: @@ -1282,8 +1282,8 @@ async def close(self): """ Called to close the Table instance and release any resources held by it. """ - if self._register_instance_task: - self._register_instance_task.cancel() + if self._register_instance_future: + self._register_instance_future.cancel() await self.client._remove_instance_registration(self.instance_id, self) async def __aenter__(self): @@ -1293,8 +1293,8 @@ async def __aenter__(self): Ensure registration task has time to run, so that grpc channels will be warmed for the specified instance """ - if self._register_instance_task: - await self._register_instance_task + if self._register_instance_future: + await self._register_instance_future return self async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index fdd9f880a..70d336c72 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -1271,9 +1271,11 @@ def __init__( default_mutate_rows_retryable_errors or () ) self.default_retryable_errors = default_retryable_errors or () - self._register_with_client() + self._register_instance_future: concurrent.futures.Future[ + None + ] = self._register_with_client() - def _register_with_client(self): + def _register_with_client(self) -> concurrent.futures.Future[None]: raise NotImplementedError("Function not implemented in sync class") def read_rows_stream( @@ -1950,20 +1952,12 @@ def read_modify_write_row( def close(self): """Called to close the Table instance and release any resources held by it.""" - if self._register_instance_task: - self._register_instance_task.cancel() + if self._register_instance_future: + self._register_instance_future.cancel() self.client._remove_instance_registration(self.instance_id, self) def __enter__(self): - """ - Implement async context manager protocol - - Ensure registration task has time to run, so that - grpc channels will be warmed for the specified instance - """ - if self._register_instance_task: - self._register_instance_task - return self + raise NotImplementedError("Function not implemented in sync class") def __exit__(self, exc_type, exc_val, exc_tb): """ diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 281a9f024..9f0aea9e8 100644 --- 
a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -88,9 +88,10 @@ def close(self) -> None: class Table(Table_SyncGen): - def _register_with_client(self): - self.client._register_instance(self.instance_id, self) - self._register_instance_task = None + def _register_with_client(self) -> concurrent.futures.Future[None]: + return self.client._executor.submit( + self.client._register_instance, self.instance_id, self + ) def _shard_batch_helper(self, kwargs_list: list[dict]) -> list[list[Row] | BaseException]: with concurrent.futures.ThreadPoolExecutor() as executor: @@ -102,3 +103,14 @@ def _shard_batch_helper(self, kwargs_list: list[dict]) -> list[list[Row] | BaseE else: results_list.append(future.result()) return results_list + + def __enter__(self): + """ + Implement context manager protocol + + Ensure registration task has time to run, so that + grpc channels will be warmed for the specified instance + """ + if self._register_instance_future: + self._register_instance_future.result() + return self diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index d8072b30c..850369331 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -57,6 +57,6 @@ classes: # Specify transformations for individual classes - path: google.cloud.bigtable.data._async.client.TableAsync autogen_sync_name: Table_SyncGen concrete_path: google.cloud.bigtable.data._sync.client.Table - error_methods: ["_register_with_client", "_shard_batch_helper"] + error_methods: ["_register_with_client", "_shard_batch_helper", "__aenter__"] save_path: "google/cloud/bigtable/data/_sync/_autogen.py" diff --git a/google/cloud/bigtable/data/_sync/unit_tests.yaml b/google/cloud/bigtable/data/_sync/unit_tests.yaml index caad4ab15..103168f6c 100644 --- a/google/cloud/bigtable/data/_sync/unit_tests.yaml +++ b/google/cloud/bigtable/data/_sync/unit_tests.yaml @@ -82,6 +82,8 @@ classes: _get_target_class: | from google.cloud.bigtable.data._sync.client import Table return Table + is_async: "return False" + drop_methods: ["test_table_ctor_sync"] #- path: tests.unit.data._async.test_client.TestReadRowsShardedAsync # autogen_sync_name: TestReadRowsSharded #- path: tests.unit.data._async.test_client.TestReadRowsAsync diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 4c2fae2f0..0b5c10f58 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1071,6 +1071,9 @@ def _get_target_class(): from google.cloud.bigtable.data._async.client import TableAsync return TableAsync + @property + def is_async(self): + return True @pytest.mark.asyncio async def test_table_ctor(self): @@ -1129,10 +1132,10 @@ async def test_table_ctor(self): == expected_mutate_rows_attempt_timeout ) # ensure task reaches completion - await table._register_instance_task - assert table._register_instance_task.done() - assert not table._register_instance_task.cancelled() - assert table._register_instance_task.exception() is None + await table._register_instance_future + assert table._register_instance_future.done() + assert not table._register_instance_future.cancelled() + assert table._register_instance_future.exception() is None await client.close() @pytest.mark.asyncio @@ -1206,49 +1209,49 @@ def test_table_ctor_sync(self): @pytest.mark.asyncio # iterate over all retryable rpcs @pytest.mark.parametrize( - 
"fn_name,fn_args,retry_fn_path,extra_retryables", + "fn_name,fn_args,is_stream,extra_retryables", [ ( "read_rows_stream", (ReadRowsQuery(),), - "google.api_core.retry.retry_target_stream_async", + True, (), ), ( "read_rows", (ReadRowsQuery(),), - "google.api_core.retry.retry_target_stream_async", + True, (), ), ( "read_row", (b"row_key",), - "google.api_core.retry.retry_target_stream_async", + True, (), ), ( "read_rows_sharded", ([ReadRowsQuery()],), - "google.api_core.retry.retry_target_stream_async", + True, (), ), ( "row_exists", (b"row_key",), - "google.api_core.retry.retry_target_stream_async", + True, (), ), - ("sample_row_keys", (), "google.api_core.retry.retry_target_async", ()), + ("sample_row_keys", (), False, ()), ( "mutate_row", (b"row_key", [mock.Mock()]), - "google.api_core.retry.retry_target_async", + False, (), ), ( "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mock.Mock()])],), - "google.api_core.retry.retry_target_async", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + False, (_MutateRowsIncomplete,), ), ], @@ -1283,14 +1286,19 @@ async def test_customizable_retryable_errors( expected_retryables, fn_name, fn_args, - retry_fn_path, + is_stream, extra_retryables, ): """ Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. """ - with mock.patch(retry_fn_path) as retry_fn_mock: + retry_fn = "retry_target" + if is_stream: + retry_fn += "_stream" + if self.is_async: + retry_fn += "_async" + with mock.patch(f"google.api_core.retry.{retry_fn}") as retry_fn_mock: async with self._make_client() as client: table = client.get_table("instance-id", "table-id") expected_predicate = lambda a: a in expected_retryables # noqa diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py index 10e7e1ad4..8baafb0d7 100644 --- a/tests/unit/data/_sync/test_autogen.py +++ b/tests/unit/data/_sync/test_autogen.py @@ -929,6 +929,10 @@ def _get_target_class(): return Table + @property + def is_async(self): + return False + def test_table_ctor(self): from google.cloud.bigtable.data._async.client import _WarmedInstanceKey @@ -983,10 +987,10 @@ def test_table_ctor(self): table.default_mutate_rows_attempt_timeout == expected_mutate_rows_attempt_timeout ) - table._register_instance_task - assert table._register_instance_task.done() - assert not table._register_instance_task.cancelled() - assert table._register_instance_task.exception() is None + table._register_instance_future + assert table._register_instance_future.done() + assert not table._register_instance_future.cancelled() + assert table._register_instance_future.exception() is None client.close() def test_table_ctor_defaults(self): @@ -1032,56 +1036,20 @@ def test_table_ctor_invalid_timeout_values(self): assert "operation_timeout must be greater than 0" in str(e.value) client.close() - def test_table_ctor_sync(self): - client = mock.Mock() - with pytest.raises(RuntimeError) as e: - Table(client, "instance-id", "table-id") - assert e.match("TableAsync must be created within an async event loop context.") - @pytest.mark.parametrize( - "fn_name,fn_args,retry_fn_path,extra_retryables", + "fn_name,fn_args,is_stream,extra_retryables", [ - ( - "read_rows_stream", - (ReadRowsQuery(),), - "google.api_core.retry.retry_target_stream_async", - (), - ), - ( - "read_rows", - (ReadRowsQuery(),), - "google.api_core.retry.retry_target_stream_async", - (), - ), - ( - "read_row", - (b"row_key",), - 
"google.api_core.retry.retry_target_stream_async", - (), - ), - ( - "read_rows_sharded", - ([ReadRowsQuery()],), - "google.api_core.retry.retry_target_stream_async", - (), - ), - ( - "row_exists", - (b"row_key",), - "google.api_core.retry.retry_target_stream_async", - (), - ), - ("sample_row_keys", (), "google.api_core.retry.retry_target_async", ()), - ( - "mutate_row", - (b"row_key", [mock.Mock()]), - "google.api_core.retry.retry_target_async", - (), - ), + ("read_rows_stream", (ReadRowsQuery(),), True, ()), + ("read_rows", (ReadRowsQuery(),), True, ()), + ("read_row", (b"row_key",), True, ()), + ("read_rows_sharded", ([ReadRowsQuery()],), True, ()), + ("row_exists", (b"row_key",), True, ()), + ("sample_row_keys", (), False, ()), + ("mutate_row", (b"row_key", [mock.Mock()]), False, ()), ( "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mock.Mock()])],), - "google.api_core.retry.retry_target_async", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + False, (_MutateRowsIncomplete,), ), ], @@ -1115,14 +1083,19 @@ def test_customizable_retryable_errors( expected_retryables, fn_name, fn_args, - retry_fn_path, + is_stream, extra_retryables, ): """ Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. """ - with mock.patch(retry_fn_path) as retry_fn_mock: + retry_fn = "retry_target" + if is_stream: + retry_fn += "_stream" + if self.is_async: + retry_fn += "_async" + with mock.patch(f"google.api_core.retry.{retry_fn}") as retry_fn_mock: with self._make_client() as client: table = client.get_table("instance-id", "table-id") expected_predicate = lambda a: a in expected_retryables From 1d428a6666626ebb64d9b8e567b4d8707ad3686a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Apr 2024 16:52:16 -0700 Subject: [PATCH 042/360] got all client tests passing --- google/cloud/bigtable/data/_async/client.py | 69 +- google/cloud/bigtable/data/_sync/_autogen.py | 61 +- google/cloud/bigtable/data/_sync/client.py | 3 +- .../cloud/bigtable/data/_sync/unit_tests.yaml | 77 +- tests/unit/data/_async/test_client.py | 94 +- tests/unit/data/_sync/test_autogen.py | 5218 +++++++++++++---- 6 files changed, 4370 insertions(+), 1152 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 062b33bc5..ac42b580a 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -60,17 +60,10 @@ from google.cloud.bigtable.data.exceptions import FailedQueryShardError from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup +from google.cloud.bigtable.data import _helpers +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync -from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data._helpers import _WarmedInstanceKey -from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT -from google.cloud.bigtable.data._helpers import _make_metadata -from google.cloud.bigtable.data._helpers import _retry_exception_factory -from google.cloud.bigtable.data._helpers import _validate_timeouts -from google.cloud.bigtable.data._helpers import _get_retryable_errors -from google.cloud.bigtable.data._helpers import _get_timeouts -from google.cloud.bigtable.data._helpers import 
_attempt_timeout_generator from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule @@ -161,10 +154,10 @@ def __init__( PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport ) # keep track of active instances to for warmup on channel refresh - self._active_instances: Set[_WarmedInstanceKey] = set() + self._active_instances: Set[_helpers._WarmedInstanceKey] = set() # keep track of table objects associated with each instance # only remove instance from _active_instances when all associated tables remove it - self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} + self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() self._channel_refresh_tasks: list[asyncio.Task[None]] = [] if self._emulator_host is not None: @@ -232,7 +225,7 @@ async def close(self, timeout: float = 2.0): self._channel_refresh_tasks = [] async def _ping_and_warm_instances( - self, channel: Channel, instance_key: _WarmedInstanceKey | None = None + self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel @@ -352,7 +345,7 @@ async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: owners call _remove_instance_registration """ instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _WarmedInstanceKey( + instance_key = _helpers._WarmedInstanceKey( instance_name, owner.table_name, owner.app_profile_id ) self._instance_owners.setdefault(instance_key, set()).add(id(owner)) @@ -385,7 +378,7 @@ async def _remove_instance_registration( - True if instance was removed """ instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _WarmedInstanceKey( + instance_key = _helpers._WarmedInstanceKey( instance_name, owner.table_name, owner.app_profile_id ) owner_list = self._instance_owners.get(instance_key, set()) @@ -518,15 +511,15 @@ def __init__( # NOTE: any changes to the signature of this method should also be reflected # in client.get_table() # validate timeouts - _validate_timeouts( + _helpers._validate_timeouts( default_operation_timeout, default_attempt_timeout, allow_none=True ) - _validate_timeouts( + _helpers._validate_timeouts( default_read_rows_operation_timeout, default_read_rows_attempt_timeout, allow_none=True, ) - _validate_timeouts( + _helpers._validate_timeouts( default_mutate_rows_operation_timeout, default_mutate_rows_attempt_timeout, allow_none=True, @@ -618,10 +611,10 @@ async def read_rows_stream( from any retries that failed - GoogleAPIError: raised if the request encounters an unrecoverable error """ - operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) - retryable_excs = _get_retryable_errors(retryable_errors, self) + retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) row_merger = _ReadRowsOperationAsync( query, @@ -778,16 +771,16 @@ async def read_rows_sharded( if not sharded_query: raise ValueError("empty sharded_query") # reduce operation_timeout between batches - operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) - 
timeout_generator = _attempt_timeout_generator( + timeout_generator = _helpers._attempt_timeout_generator( operation_timeout, operation_timeout ) # submit shards in batches if the number of shards goes over _CONCURRENCY_LIMIT batched_queries = [ - sharded_query[i : i + _CONCURRENCY_LIMIT] - for i in range(0, len(sharded_query), _CONCURRENCY_LIMIT) + sharded_query[i : i + _helpers._CONCURRENCY_LIMIT] + for i in range(0, len(sharded_query), _helpers._CONCURRENCY_LIMIT) ] # run batches and collect results results_list = [] @@ -931,20 +924,20 @@ async def sample_row_keys( - GoogleAPIError: raised if the request encounters an unrecoverable error """ # prepare timeouts - operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) - attempt_timeout_gen = _attempt_timeout_generator( + attempt_timeout_gen = _helpers._attempt_timeout_generator( attempt_timeout, operation_timeout ) # prepare retryable - retryable_excs = _get_retryable_errors(retryable_errors, self) + retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) predicate = retries.if_exception_type(*retryable_excs) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) # prepare request - metadata = _make_metadata(self.table_name, self.app_profile_id) + metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) async def execute_rpc(): results = await self.client._gapic_client.sample_row_keys( @@ -961,7 +954,7 @@ async def execute_rpc(): predicate, sleep_generator, operation_timeout, - exception_factory=_retry_exception_factory, + exception_factory=_helpers._retry_exception_factory, ) def mutations_batcher( @@ -1060,7 +1053,7 @@ async def mutate_row( safely retried. 
- ValueError if invalid arguments are provided """ - operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) @@ -1071,7 +1064,7 @@ async def mutate_row( if all(mutation.is_idempotent() for mutation in mutations_list): # mutations are all idempotent and safe to retry predicate = retries.if_exception_type( - *_get_retryable_errors(retryable_errors, self) + *_helpers._get_retryable_errors(retryable_errors, self) ) else: # mutations should not be retried @@ -1086,7 +1079,7 @@ async def mutate_row( table_name=self.table_name, app_profile_id=self.app_profile_id, timeout=attempt_timeout, - metadata=_make_metadata(self.table_name, self.app_profile_id), + metadata=_helpers._make_metadata(self.table_name, self.app_profile_id), retry=None, ) return await retries.retry_target_async( @@ -1094,7 +1087,7 @@ async def mutate_row( predicate, sleep_generator, operation_timeout, - exception_factory=_retry_exception_factory, + exception_factory=_helpers._retry_exception_factory, ) async def bulk_mutate_rows( @@ -1140,10 +1133,10 @@ async def bulk_mutate_rows( Contains details about any failed entries in .exceptions - ValueError if invalid arguments are provided """ - operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) - retryable_excs = _get_retryable_errors(retryable_errors, self) + retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) operation = _MutateRowsOperationAsync( self.client._gapic_client, @@ -1199,7 +1192,7 @@ async def check_and_mutate_row( Raises: - GoogleAPIError exceptions from grpc call """ - operation_timeout, _ = _get_timeouts(operation_timeout, None, self) + operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) if true_case_mutations is not None and not isinstance( true_case_mutations, list ): @@ -1210,7 +1203,7 @@ async def check_and_mutate_row( ): false_case_mutations = [false_case_mutations] false_case_list = [m._to_pb() for m in false_case_mutations or []] - metadata = _make_metadata(self.table_name, self.app_profile_id) + metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) result = await self.client._gapic_client.check_and_mutate_row( true_mutations=true_case_list, false_mutations=false_case_list, @@ -1258,14 +1251,14 @@ async def read_modify_write_row( - GoogleAPIError exceptions from grpc call - ValueError if invalid arguments are provided """ - operation_timeout, _ = _get_timeouts(operation_timeout, None, self) + operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) if operation_timeout <= 0: raise ValueError("operation_timeout must be greater than 0") if rules is not None and not isinstance(rules, list): rules = [rules] if not rules: raise ValueError("rules must contain at least one item") - metadata = _make_metadata(self.table_name, self.app_profile_id) + metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) result = await self.client._gapic_client.read_modify_write_row( rules=[rule._to_pb() for rule in rules], row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 70d336c72..5486d376b 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -50,14 +50,11 @@ from 
google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT -from google.cloud.bigtable.data._helpers import _WarmedInstanceKey from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory -from google.cloud.bigtable.data._helpers import _validate_timeouts from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data.exceptions import FailedQueryShardError from google.cloud.bigtable.data.exceptions import InvalidChunk @@ -931,8 +928,8 @@ def __init__( ) self._is_closed = threading.Event() self.transport = cast(PooledBigtableGrpcTransport, self._gapic_client.transport) - self._active_instances: Set[_WarmedInstanceKey] = set() - self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} + self._active_instances: Set[_helpers._WarmedInstanceKey] = set() + self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() self._channel_refresh_tasks: list[threading.Thread[None]] = [] if self._emulator_host is not None: @@ -964,7 +961,7 @@ def _start_background_channel_refresh(self) -> None: raise NotImplementedError("Function not implemented in sync class") def _ping_and_warm_instances( - self, channel: Channel, instance_key: _WarmedInstanceKey | None = None + self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel @@ -1069,7 +1066,7 @@ def _register_instance( owners call _remove_instance_registration """ instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _WarmedInstanceKey( + instance_key = _helpers._WarmedInstanceKey( instance_name, owner.table_name, owner.app_profile_id ) self._instance_owners.setdefault(instance_key, set()).add(id(owner)) @@ -1099,7 +1096,7 @@ def _remove_instance_registration( - True if instance was removed """ instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _WarmedInstanceKey( + instance_key = _helpers._WarmedInstanceKey( instance_name, owner.table_name, owner.app_profile_id ) owner_list = self._instance_owners.get(instance_key, set()) @@ -1233,15 +1230,15 @@ def __init__( Raises: - RuntimeError if called outside of an async context (no running event loop) """ - _validate_timeouts( + _helpers._validate_timeouts( default_operation_timeout, default_attempt_timeout, allow_none=True ) - _validate_timeouts( + _helpers._validate_timeouts( default_read_rows_operation_timeout, default_read_rows_attempt_timeout, allow_none=True, ) - _validate_timeouts( + _helpers._validate_timeouts( default_mutate_rows_operation_timeout, default_mutate_rows_attempt_timeout, allow_none=True, @@ -1317,10 +1314,10 @@ def read_rows_stream( from any retries that failed - GoogleAPIError: raised if the request encounters an unrecoverable error """ - (operation_timeout, attempt_timeout) = _get_timeouts( + (operation_timeout, attempt_timeout) = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) - retryable_excs = 
_get_retryable_errors(retryable_errors, self) + retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) row_merger = google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation( query, self, @@ -1475,15 +1472,15 @@ def read_rows_sharded( """ if not sharded_query: raise ValueError("empty sharded_query") - (operation_timeout, attempt_timeout) = _get_timeouts( + (operation_timeout, attempt_timeout) = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) - timeout_generator = _attempt_timeout_generator( + timeout_generator = _helpers._attempt_timeout_generator( operation_timeout, operation_timeout ) batched_queries = [ - sharded_query[i : i + _CONCURRENCY_LIMIT] - for i in range(0, len(sharded_query), _CONCURRENCY_LIMIT) + sharded_query[i : i + _helpers._CONCURRENCY_LIMIT] + for i in range(0, len(sharded_query), _helpers._CONCURRENCY_LIMIT) ] results_list = [] error_dict = {} @@ -1615,16 +1612,16 @@ def sample_row_keys( from any retries that failed - GoogleAPIError: raised if the request encounters an unrecoverable error """ - (operation_timeout, attempt_timeout) = _get_timeouts( + (operation_timeout, attempt_timeout) = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) - attempt_timeout_gen = _attempt_timeout_generator( + attempt_timeout_gen = _helpers._attempt_timeout_generator( attempt_timeout, operation_timeout ) - retryable_excs = _get_retryable_errors(retryable_errors, self) + retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) predicate = retries.if_exception_type(*retryable_excs) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - metadata = _make_metadata(self.table_name, self.app_profile_id) + metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) def execute_rpc(): results = self.client._gapic_client.sample_row_keys( @@ -1641,7 +1638,7 @@ def execute_rpc(): predicate, sleep_generator, operation_timeout, - exception_factory=_retry_exception_factory, + exception_factory=_helpers._retry_exception_factory, ) def mutations_batcher( @@ -1740,7 +1737,7 @@ def mutate_row( safely retried. 
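The pattern repeated throughout this patch is a switch from importing each private helper by name to importing the _helpers module and using qualified references; presumably this leaves a single import line per module for the sync generator's rewrites to track, though the patch itself does not state the motivation. In isolation the new style looks like the sketch below (signatures taken from the diff above; assumes the library is installed).

    from google.cloud.bigtable.data import _helpers

    def _resolve_timeouts(table, operation_timeout, attempt_timeout, retryable_errors):
        # qualified _helpers.* access instead of individually imported names
        operation_timeout, attempt_timeout = _helpers._get_timeouts(
            operation_timeout, attempt_timeout, table
        )
        retryable_excs = _helpers._get_retryable_errors(retryable_errors, table)
        return operation_timeout, attempt_timeout, retryable_excs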
- ValueError if invalid arguments are provided """ - (operation_timeout, attempt_timeout) = _get_timeouts( + (operation_timeout, attempt_timeout) = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) if not mutations: @@ -1748,7 +1745,7 @@ def mutate_row( mutations_list = mutations if isinstance(mutations, list) else [mutations] if all((mutation.is_idempotent() for mutation in mutations_list)): predicate = retries.if_exception_type( - *_get_retryable_errors(retryable_errors, self) + *_helpers._get_retryable_errors(retryable_errors, self) ) else: predicate = retries.if_exception_type() @@ -1760,7 +1757,7 @@ def mutate_row( table_name=self.table_name, app_profile_id=self.app_profile_id, timeout=attempt_timeout, - metadata=_make_metadata(self.table_name, self.app_profile_id), + metadata=_helpers._make_metadata(self.table_name, self.app_profile_id), retry=None, ) return retries.retry_target( @@ -1768,7 +1765,7 @@ def mutate_row( predicate, sleep_generator, operation_timeout, - exception_factory=_retry_exception_factory, + exception_factory=_helpers._retry_exception_factory, ) def bulk_mutate_rows( @@ -1814,10 +1811,10 @@ def bulk_mutate_rows( Contains details about any failed entries in .exceptions - ValueError if invalid arguments are provided """ - (operation_timeout, attempt_timeout) = _get_timeouts( + (operation_timeout, attempt_timeout) = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) - retryable_excs = _get_retryable_errors(retryable_errors, self) + retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) operation = google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation( self.client._gapic_client, self, @@ -1872,7 +1869,7 @@ def check_and_mutate_row( Raises: - GoogleAPIError exceptions from grpc call """ - (operation_timeout, _) = _get_timeouts(operation_timeout, None, self) + (operation_timeout, _) = _helpers._get_timeouts(operation_timeout, None, self) if true_case_mutations is not None and ( not isinstance(true_case_mutations, list) ): @@ -1883,7 +1880,7 @@ def check_and_mutate_row( ): false_case_mutations = [false_case_mutations] false_case_list = [m._to_pb() for m in false_case_mutations or []] - metadata = _make_metadata(self.table_name, self.app_profile_id) + metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) result = self.client._gapic_client.check_and_mutate_row( true_mutations=true_case_list, false_mutations=false_case_list, @@ -1931,14 +1928,14 @@ def read_modify_write_row( - GoogleAPIError exceptions from grpc call - ValueError if invalid arguments are provided """ - (operation_timeout, _) = _get_timeouts(operation_timeout, None, self) + (operation_timeout, _) = _helpers._get_timeouts(operation_timeout, None, self) if operation_timeout <= 0: raise ValueError("operation_timeout must be greater than 0") if rules is not None and (not isinstance(rules, list)): rules = [rules] if not rules: raise ValueError("rules must contain at least one item") - metadata = _make_metadata(self.table_name, self.app_profile_id) + metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) result = self.client._gapic_client.read_modify_write_row( rules=[rule._to_pb() for rule in rules], row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 9f0aea9e8..2d1d41814 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -94,8 +94,7 
@@ def _register_with_client(self) -> concurrent.futures.Future[None]: ) def _shard_batch_helper(self, kwargs_list: list[dict]) -> list[list[Row] | BaseException]: - with concurrent.futures.ThreadPoolExecutor() as executor: - futures_list = [executor.submit(self.read_rows, **kwargs) for kwargs in kwargs_list] + futures_list = [self.client._executor.submit(self.read_rows, **kwargs) for kwargs in kwargs_list] results_list: list[list[Row] | BaseException] = [] for future in futures_list: if future.exception(): diff --git a/google/cloud/bigtable/data/_sync/unit_tests.yaml b/google/cloud/bigtable/data/_sync/unit_tests.yaml index 103168f6c..1475f94b1 100644 --- a/google/cloud/bigtable/data/_sync/unit_tests.yaml +++ b/google/cloud/bigtable/data/_sync/unit_tests.yaml @@ -27,38 +27,39 @@ text_replacements: # Find and replace specific text patterns AsyncMock: mock.Mock retry_target_async: retry_target TestBigtableDataClientAsync: TestBigtableDataClient + TestReadRowsAsync: TestReadRows assert_awaited_once: assert_called_once assert_awaited: assert_called_once grpc_helpers_async: grpc_helpers classes: - #- path: tests.unit.data._async.test__mutate_rows.TestMutateRowsOperation - # autogen_sync_name: TestMutateRowsOperation - # replace_methods: - # _target_class: | - # from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation - # return _MutateRowsOperation - #- path: tests.unit.data._async.test__read_rows.TestReadRowsOperation - # autogen_sync_name: TestReadRowsOperation - # text_replacements: - # test_aclose: test_close - # replace_methods: - # _get_target_class: | - # from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - # return _ReadRowsOperation - #- path: tests.unit.data._async.test_mutations_batcher.Test_FlowControl - # autogen_sync_name: Test_FlowControl - # replace_methods: - # _target_class: | - # from google.cloud.bigtable.data._sync.mutations_batcher import _FlowControl - # return _FlowControl - #- path: tests.unit.data._async.test_mutations_batcher.TestMutationsBatcherAsync - # autogen_sync_name: TestMutationsBatcher - # replace_methods: - # _get_target_class: | - # from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher - # return MutationsBatcher - # is_async: "return False" + - path: tests.unit.data._async.test__mutate_rows.TestMutateRowsOperation + autogen_sync_name: TestMutateRowsOperation + replace_methods: + _target_class: | + from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation + return _MutateRowsOperation + - path: tests.unit.data._async.test__read_rows.TestReadRowsOperation + autogen_sync_name: TestReadRowsOperation + text_replacements: + test_aclose: test_close + replace_methods: + _get_target_class: | + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + return _ReadRowsOperation + - path: tests.unit.data._async.test_mutations_batcher.Test_FlowControl + autogen_sync_name: Test_FlowControl + replace_methods: + _target_class: | + from google.cloud.bigtable.data._sync.mutations_batcher import _FlowControl + return _FlowControl + - path: tests.unit.data._async.test_mutations_batcher.TestMutationsBatcherAsync + autogen_sync_name: TestMutationsBatcher + replace_methods: + _get_target_class: | + from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher + return MutationsBatcher + is_async: "return False" - path: tests.unit.data._async.test_client.TestBigtableDataClientAsync autogen_sync_name: TestBigtableDataClient added_imports: @@ -84,9 
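_shard_batch_helper now reuses the client-wide executor rather than spinning up a ThreadPoolExecutor per call, and it keeps returning per-shard exceptions as values instead of raising. The fan-out shape, reduced to a standalone sketch (names here are illustrative, not the library API):

    import concurrent.futures

    def fan_out(executor, fn, kwargs_list):
        futures = [executor.submit(fn, **kwargs) for kwargs in kwargs_list]
        results = []
        for future in futures:
            exc = future.exception()  # blocks until this future completes
            results.append(exc if exc is not None else future.result())
        return results

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        assert fan_out(executor, lambda x: x * 2, [{"x": 1}, {"x": 2}]) == [2, 4]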
+85,23 @@ classes: return Table is_async: "return False" drop_methods: ["test_table_ctor_sync"] - #- path: tests.unit.data._async.test_client.TestReadRowsShardedAsync - # autogen_sync_name: TestReadRowsSharded - #- path: tests.unit.data._async.test_client.TestReadRowsAsync - # autogen_sync_name: TestReadRows + - path: tests.unit.data._async.test_client.TestReadRowsAsync + autogen_sync_name: TestReadRows + replace_methods: + _get_operation_class: | + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + return _ReadRowsOperation + - path: tests.unit.data._async.test_client.TestReadRowsShardedAsync + autogen_sync_name: TestReadRowsSharded + - path: tests.unit.data._async.test_client.TestSampleRowKeysAsync + autogen_sync_name: TestSampleRowKeys + - path: tests.unit.data._async.test_client.TestMutateRowAsync + autogen_sync_name: TestMutateRow + - path: tests.unit.data._async.test_client.TestBulkMutateRowsAsync + autogen_sync_name: TestBulkMutateRows + - path: tests.unit.data._async.test_client.TestCheckAndMutateRowAsync + autogen_sync_name: TestCheckAndMutateRow + - path: tests.unit.data._async.test_client.TestReadModifyWriteRowAsync + autogen_sync_name: TestReadModifyWriteRow save_path: "tests/unit/data/_sync/test_autogen.py" diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 0b5c10f58..c2bbdcbb4 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -791,7 +791,7 @@ async def test__multiple_table_registration(self): add multiple owners to instance_owners, but only keep one copy of shared key in active_instances """ - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey async with self._make_client(project="project-id") as client: async with client.get_table("instance_1", "table_1") as table_1: @@ -839,7 +839,7 @@ async def test__multiple_instance_registration(self): registering with multiple instance keys should update the key in instance_owners and active_instances """ - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey async with self._make_client(project="project-id") as client: async with client.get_table("instance_1", "table_1") as table_1: @@ -874,8 +874,7 @@ async def test__multiple_instance_registration(self): @pytest.mark.asyncio async def test_get_table(self): - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey client = self._make_client(project="project-id") assert not client._active_instances @@ -888,7 +887,7 @@ async def test_get_table(self): expected_app_profile_id, ) await asyncio.sleep(0) - assert isinstance(table, TableAsync) + assert isinstance(table, TestTableAsync._get_target_class()) assert table.table_id == expected_table_id assert ( table.table_name @@ -941,15 +940,14 @@ async def test_get_table_arg_passthrough(self): @pytest.mark.asyncio async def test_get_table_context_manager(self): - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey expected_table_id = "table-id" expected_instance_id = "instance-id" expected_app_profile_id = "app-profile-id" expected_project_id = 
"project-id" - with mock.patch.object(TableAsync, "close") as close_mock: + with mock.patch.object(TestTableAsync._get_target_class(), "close") as close_mock: async with self._make_client(project=expected_project_id) as client: async with client.get_table( expected_instance_id, @@ -957,7 +955,7 @@ async def test_get_table_context_manager(self): expected_app_profile_id, ) as table: await asyncio.sleep(0) - assert isinstance(table, TableAsync) + assert isinstance(table, TestTableAsync._get_target_class()) assert table.table_id == expected_table_id assert ( table.table_name @@ -1077,7 +1075,7 @@ def is_async(self): @pytest.mark.asyncio async def test_table_ctor(self): - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -1379,6 +1377,11 @@ class TestReadRowsAsync: Tests for table.read_rows and related methods. """ + @staticmethod + def _get_operation_class(): + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + return _ReadRowsOperationAsync + def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -1668,13 +1671,12 @@ async def test_read_rows_revise_request(self): """ Ensure that _revise_request is called between retries """ - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable_v2.types import RowSet return_val = RowSet() with mock.patch.object( - _ReadRowsOperationAsync, "_revise_request_rowset" + self._get_operation_class(), "_revise_request_rowset" ) as revise_rowset: revise_rowset.return_value = return_val async with self._make_table() as table: @@ -1703,11 +1705,9 @@ async def test_read_rows_default_timeouts(self): """ Ensure that the default timeouts are set on the read rows operation when not overridden """ - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - operation_timeout = 8 attempt_timeout = 4 - with mock.patch.object(_ReadRowsOperationAsync, "__init__") as mock_op: + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: mock_op.side_effect = RuntimeError("mock error") async with self._make_table( default_read_rows_operation_timeout=operation_timeout, @@ -1726,11 +1726,9 @@ async def test_read_rows_default_timeout_override(self): """ When timeouts are passed, they overwrite default values """ - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - operation_timeout = 8 attempt_timeout = 4 - with mock.patch.object(_ReadRowsOperationAsync, "__init__") as mock_op: + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: mock_op.side_effect = RuntimeError("mock error") async with self._make_table( default_operation_timeout=99, default_attempt_timeout=97 @@ -1904,9 +1902,9 @@ async def test_read_rows_sharded_multiple_queries(self): table.client._gapic_client, "read_rows" ) as read_rows: read_rows.side_effect = ( - lambda *args, **kwargs: TestReadRows._make_gapic_stream( + lambda *args, **kwargs: TestReadRowsAsync._make_gapic_stream( [ - TestReadRows._make_chunk(row_key=k) + TestReadRowsAsync._make_chunk(row_key=k) for k in args[0].rows.row_keys ] ) @@ -1989,9 +1987,7 @@ async def test_read_rows_sharded_batching(self): Large queries should be processed in batches to limit concurrency operation timeout should change between 
batches """ - import functools - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT + from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT assert _CONCURRENCY_LIMIT == 10 # change this test if this changes @@ -1999,23 +1995,32 @@ async def test_read_rows_sharded_batching(self): expected_num_batches = n_queries // _CONCURRENCY_LIMIT query_list = [ReadRowsQuery() for _ in range(n_queries)] - table_mock = AsyncMock() start_operation_timeout = 10 start_attempt_timeout = 3 - table_mock.default_read_rows_operation_timeout = start_operation_timeout - table_mock.default_read_rows_attempt_timeout = start_attempt_timeout - table_mock._shard_batch_helper = functools.partial(TableAsync._shard_batch_helper, table_mock) - # clock ticks one second on each check - with mock.patch("time.monotonic", side_effect=range(0, 100000)): - with mock.patch("asyncio.gather", AsyncMock()) as gather_mock: - await TableAsync.read_rows_sharded(table_mock, query_list) + + client = self._make_client(use_emulator=True) + table = client.get_table( + "instance", "table", + default_read_rows_operation_timeout=start_operation_timeout, + default_read_rows_attempt_timeout=start_attempt_timeout + ) + + # make timeout generator that reduces timeout by one each call + def mock_time_generator(start_op, _): + for i in range(0, 100000): + yield start_op - i + + with mock.patch(f"google.cloud.bigtable.data._helpers._attempt_timeout_generator") as time_gen_mock: + time_gen_mock.side_effect = mock_time_generator + + with mock.patch.object(table, "read_rows", AsyncMock()) as read_rows_mock: + read_rows_mock.return_value = [] + await table.read_rows_sharded(query_list) # should have individual calls for each query - assert table_mock.read_rows.call_count == n_queries - # should have single gather call for each batch - assert gather_mock.call_count == expected_num_batches + assert read_rows_mock.call_count == n_queries # ensure that timeouts decrease over time kwargs = [ - table_mock.read_rows.call_args_list[idx][1] + read_rows_mock.call_args_list[idx][1] for idx in range(n_queries) ] for batch_idx in range(expected_num_batches): @@ -2027,24 +2032,19 @@ async def test_read_rows_sharded_batching(self): for req_kwargs in batch_kwargs: # each batch should have the same operation_timeout, and it should decrease in each batch expected_operation_timeout = start_operation_timeout - ( - batch_idx + 1 + batch_idx ) assert ( - req_kwargs["operation_timeout"] - == expected_operation_timeout + req_kwargs["operation_timeout"] == expected_operation_timeout ) # each attempt_timeout should start with default value, but decrease when operation_timeout reaches it expected_attempt_timeout = min( start_attempt_timeout, expected_operation_timeout ) assert req_kwargs["attempt_timeout"] == expected_attempt_timeout - # await all created coroutines to avoid warnings - for i in range(len(gather_mock.call_args_list)): - for j in range(len(gather_mock.call_args_list[i][0])): - await gather_mock.call_args_list[i][0][j] -class TestSampleRowKeys: +class TestSampleRowKeysAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2196,7 +2196,7 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio await table.sample_row_keys() -class TestMutateRow: +class TestMutateRowAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ 
-2372,7 +2372,7 @@ async def test_mutate_row_no_mutations(self, mutations): assert e.value.args[0] == "No mutations provided" -class TestBulkMutateRows: +class TestBulkMutateRowsAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2752,7 +2752,7 @@ async def test_bulk_mutate_error_recovery(self): await table.bulk_mutate_rows(entries, operation_timeout=1000) -class TestCheckAndMutateRow: +class TestCheckAndMutateRowAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2904,7 +2904,7 @@ async def test_check_and_mutate_mutations_parsing(self): ) -class TestReadModifyWriteRow: +class TestReadModifyWriteRowAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py index 8baafb0d7..4aae958ed 100644 --- a/tests/unit/data/_sync/test_autogen.py +++ b/tests/unit/data/_sync/test_autogen.py @@ -17,22 +17,32 @@ from __future__ import annotations from abc import ABC +from tests.unit.data._async.test__mutate_rows import TestMutateRowsOperation +from tests.unit.data._async.test__read_rows import TestReadRowsOperation +from tests.unit.data._async.test_mutations_batcher import Test_FlowControl from unittest import mock import asyncio +import concurrent.futures import grpc import mock import pytest import re +import threading import time from google.api_core import exceptions as core_exceptions from google.api_core import grpc_helpers from google.auth.credentials import AnonymousCredentials +from google.cloud.bigtable.data import ReadRowsQuery from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data import Table from google.cloud.bigtable.data import mutations +from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule +from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable_v2 import ReadRowsResponse from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( PooledBigtableGrpcTransport, @@ -40,1119 +50,3717 @@ from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( PooledChannel, ) +from google.cloud.bigtable_v2.types import MutateRowsResponse +from google.cloud.bigtable_v2.types import ReadRowsResponse +from google.rpc import status_pb2 +import google.api_core.exceptions import google.api_core.exceptions as core_exceptions +import google.api_core.retry -class TestBigtableDataClient(ABC): - @staticmethod - def _get_target_class(): - from google.cloud.bigtable.data._sync.client import BigtableDataClient +class TestMutateRowsOperation(ABC): + def _target_class(self): + from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation - return BigtableDataClient + return _MutateRowsOperation - @classmethod - def _make_client(cls, *args, use_emulator=True, **kwargs): - import os + def _make_one(self, *args, **kwargs): + if not args: + kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) + kwargs["table"] = kwargs.pop("table", mock.Mock()) + kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) 
+ kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) + kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) + kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) + return self._target_class()(*args, **kwargs) - env_mask = {} - if use_emulator: - env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" - import warnings + def _make_mutation(self, count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation - warnings.filterwarnings("ignore", category=RuntimeWarning) - else: - kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) - kwargs["project"] = kwargs.get("project", "project-id") - with mock.patch.dict(os.environ, env_mask): - return cls._get_target_class()(*args, **kwargs) + def _mock_stream(self, mutation_list, error_dict): + for idx, entry in enumerate(mutation_list): + code = error_dict.get(idx, 0) + yield MutateRowsResponse( + entries=[ + MutateRowsResponse.Entry( + index=idx, status=status_pb2.Status(code=code) + ) + ] + ) - @property - def is_async(self): - return False + def _make_mock_gapic(self, mutation_list, error_dict=None): + mock_fn = mock.Mock() + if error_dict is None: + error_dict = {} + mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( + mutation_list, error_dict + ) + return mock_fn def test_ctor(self): - expected_project = "project-id" - expected_pool_size = 11 - expected_credentials = AnonymousCredentials() - client = self._make_client( - project="project-id", - pool_size=expected_pool_size, - credentials=expected_credentials, - use_emulator=False, + """test that constructor sets all the attributes correctly""" + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + from google.api_core.exceptions import DeadlineExceeded + from google.api_core.exceptions import Aborted + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + attempt_timeout = 0.01 + retryable_exceptions = () + instance = self._make_one( + client, + table, + entries, + operation_timeout, + attempt_timeout, + retryable_exceptions, ) - time.sleep(0) - assert client.project == expected_project - assert len(client.transport._grpc_channel._pool) == expected_pool_size - assert not client._active_instances - assert len(client._channel_refresh_tasks) == expected_pool_size - assert client.transport._credentials == expected_credentials - client.close() + assert client.mutate_rows.call_count == 0 + instance._gapic_fn() + assert client.mutate_rows.call_count == 1 + inner_kwargs = client.mutate_rows.call_args[1] + assert len(inner_kwargs) == 4 + assert inner_kwargs["table_name"] == table.table_name + assert inner_kwargs["app_profile_id"] == table.app_profile_id + assert inner_kwargs["retry"] is None + metadata = inner_kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert str(table.table_name) in metadata[0][1] + assert str(table.app_profile_id) in metadata[0][1] + entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries] + assert instance.mutations == entries_w_pb + assert next(instance.timeout_generator) == attempt_timeout + assert instance.is_retryable is not None + assert instance.is_retryable(DeadlineExceeded("")) is False + assert instance.is_retryable(Aborted("")) is False + assert 
instance.is_retryable(_MutateRowsIncomplete("")) is True + assert instance.is_retryable(RuntimeError("")) is False + assert instance.remaining_indices == list(range(len(entries))) + assert instance.errors == {} - def test_ctor_super_inits(self): - from google.cloud.client import ClientWithProject - from google.api_core import client_options as client_options_lib - from google.cloud.bigtable import __version__ as bigtable_version + def test_ctor_too_many_entries(self): + """should raise an error if an operation is created with more than 100,000 entries""" + from google.cloud.bigtable.data._async._mutate_rows import ( + _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, + ) - project = "project-id" - pool_size = 11 - credentials = AnonymousCredentials() - client_options = {"api_endpoint": "foo.bar:1234"} - options_parsed = client_options_lib.from_dict(client_options) - asyncio_portion = "-async" if self.is_async else "" - transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" - with mock.patch.object(BigtableClient, "__init__") as bigtable_client_init: - bigtable_client_init.return_value = None - with mock.patch.object( - ClientWithProject, "__init__" - ) as client_project_init: - client_project_init.return_value = None - try: - self._make_client( - project=project, - pool_size=pool_size, - credentials=credentials, - client_options=options_parsed, - use_emulator=False, - ) - except AttributeError: - pass - assert bigtable_client_init.call_count == 1 - kwargs = bigtable_client_init.call_args[1] - assert kwargs["transport"] == transport_str - assert kwargs["credentials"] == credentials - assert kwargs["client_options"] == options_parsed - assert client_project_init.call_count == 1 - kwargs = client_project_init.call_args[1] - assert kwargs["project"] == project - assert kwargs["credentials"] == credentials - assert kwargs["client_options"] == options_parsed + assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100000 + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT + operation_timeout = 0.05 + attempt_timeout = 0.01 + self._make_one(client, table, entries, operation_timeout, attempt_timeout) + with pytest.raises(ValueError) as e: + self._make_one( + client, + table, + entries + [self._make_mutation()], + operation_timeout, + attempt_timeout, + ) + assert "mutate_rows requests can contain at most 100000 mutations" in str( + e.value + ) + assert "Found 100001" in str(e.value) - def test_ctor_dict_options(self): - from google.api_core.client_options import ClientOptions + def test_mutate_rows_operation(self): + """Test successful case of mutate_rows_operation""" + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + cls = self._target_class() + with mock.patch( + f"{cls.__module__}.{cls.__name__}._run_attempt", mock.Mock() + ) as attempt_mock: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + assert attempt_mock.call_count == 1 - client_options = {"api_endpoint": "foo.bar:1234"} - with mock.patch.object(BigtableClient, "__init__") as bigtable_client_init: + @pytest.mark.parametrize( + "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] + ) + def test_mutate_rows_attempt_exception(self, exc_type): + """exceptions raised from attempt should be raised in MutationsExceptionGroup""" + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), 
self._make_mutation()] + operation_timeout = 0.05 + expected_exception = exc_type("test") + client.mutate_rows.side_effect = expected_exception + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance._run_attempt() + except Exception as e: + found_exc = e + assert client.mutate_rows.call_count == 1 + assert type(found_exc) is exc_type + assert found_exc == expected_exception + assert len(instance.errors) == 2 + assert len(instance.remaining_indices) == 0 + + @pytest.mark.parametrize( + "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] + ) + def test_mutate_rows_exception(self, exc_type): + """exceptions raised from retryable should be raised in MutationsExceptionGroup""" + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + expected_cause = exc_type("abort") + with mock.patch.object( + self._target_class(), "_run_attempt", mock.Mock() + ) as attempt_mock: + attempt_mock.side_effect = expected_cause + found_exc = None try: - self._make_client(client_options=client_options) - except TypeError: - pass - bigtable_client_init.assert_called_once() - kwargs = bigtable_client_init.call_args[1] - called_options = kwargs["client_options"] - assert called_options.api_endpoint == "foo.bar:1234" - assert isinstance(called_options, ClientOptions) + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + except MutationsExceptionGroup as e: + found_exc = e + assert attempt_mock.call_count == 1 + assert len(found_exc.exceptions) == 2 + assert isinstance(found_exc.exceptions[0], FailedMutationEntryError) + assert isinstance(found_exc.exceptions[1], FailedMutationEntryError) + assert found_exc.exceptions[0].__cause__ == expected_cause + assert found_exc.exceptions[1].__cause__ == expected_cause + + @pytest.mark.parametrize( + "exc_type", [core_exceptions.DeadlineExceeded, RuntimeError] + ) + def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): + """If an exception fails but eventually passes, it should not raise an exception""" + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] + operation_timeout = 1 + expected_cause = exc_type("retry") + num_retries = 2 with mock.patch.object( - self._get_target_class(), "_start_background_channel_refresh" - ) as start_background_refresh: - client = self._make_client( - client_options=client_options, use_emulator=False + self._target_class(), "_run_attempt", mock.Mock() + ) as attempt_mock: + attempt_mock.side_effect = [expected_cause] * num_retries + [None] + instance = self._make_one( + client, + table, + entries, + operation_timeout, + operation_timeout, + retryable_exceptions=(exc_type,), ) - start_background_refresh.assert_called_once() - client.close() + instance.start() + assert attempt_mock.call_count == num_retries + 1 - def test_veneer_grpc_headers(self): - client_component = "data-async" if self.is_async else "data" - VENEER_HEADER_REGEX = re.compile( - "gapic\\/[0-9]+\\.[\\w.-]+ gax\\/[0-9]+\\.[\\w.-]+ gccl\\/[0-9]+\\.[\\w.-]+-" - + client_component - + " gl-python\\/[0-9]+\\.[\\w.-]+ grpc\\/[0-9]+\\.[\\w.-]+" - ) - if self.is_async: - patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") - else: - patch = 
mock.patch("google.api_core.gapic_v1.method.wrap_method") - with patch as gapic_mock: - client = self._make_client(project="project-id") - wrapped_call_list = gapic_mock.call_args_list - assert len(wrapped_call_list) > 0 - for call in wrapped_call_list: - client_info = call.kwargs["client_info"] - assert client_info is not None, f"{call} has no client_info" - wrapped_user_agent_sorted = " ".join( - sorted(client_info.to_user_agent().split(" ")) - ) - assert VENEER_HEADER_REGEX.match( - wrapped_user_agent_sorted - ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" - client.close() + def test_mutate_rows_incomplete_ignored(self): + """MutateRowsIncomplete exceptions should not be added to error list""" + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.api_core.exceptions import DeadlineExceeded - def test_channel_pool_creation(self): - pool_size = 14 + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] + operation_timeout = 0.05 with mock.patch.object( - grpc_helpers, "create_channel", mock.Mock() - ) as create_channel: - client = self._make_client(project="project-id", pool_size=pool_size) - assert create_channel.call_count == pool_size - client.close() - client = self._make_client(project="project-id", pool_size=pool_size) - pool_list = list(client.transport._grpc_channel._pool) - pool_set = set(client.transport._grpc_channel._pool) - assert len(pool_list) == len(pool_set) - client.close() + self._target_class(), "_run_attempt", mock.Mock() + ) as attempt_mock: + attempt_mock.side_effect = _MutateRowsIncomplete("ignored") + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + except MutationsExceptionGroup as e: + found_exc = e + assert attempt_mock.call_count > 0 + assert len(found_exc.exceptions) == 1 + assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) - def test_channel_pool_rotation(self): - pool_size = 7 - with mock.patch.object(PooledChannel, "next_channel") as next_channel: - client = self._make_client(project="project-id", pool_size=pool_size) - assert len(client.transport._grpc_channel._pool) == pool_size - next_channel.reset_mock() - with mock.patch.object( - type(client.transport._grpc_channel._pool[0]), "unary_unary" - ) as unary_unary: - channel_next = None - for i in range(pool_size): - channel_last = channel_next - channel_next = client.transport.grpc_channel._pool[i] - assert channel_last != channel_next - next_channel.return_value = channel_next - client.transport.ping_and_warm() - assert next_channel.call_count == i + 1 - unary_unary.assert_called_once() - unary_unary.reset_mock() - client.close() + def test_run_attempt_single_entry_success(self): + """Test mutating a single entry""" + mutation = self._make_mutation() + expected_timeout = 1.3 + mock_gapic_fn = self._make_mock_gapic({0: mutation}) + instance = self._make_one( + mutation_entries=[mutation], attempt_timeout=expected_timeout + ) + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + instance._run_attempt() + assert len(instance.remaining_indices) == 0 + assert mock_gapic_fn.call_count == 1 + (_, kwargs) = mock_gapic_fn.call_args + assert kwargs["timeout"] == expected_timeout + assert kwargs["entries"] == [mutation._to_pb()] - def test_channel_pool_replace(self): - import time + def test_run_attempt_empty_request(self): + 
"""Calling with no mutations should result in no API calls""" + mock_gapic_fn = self._make_mock_gapic([]) + instance = self._make_one(mutation_entries=[]) + instance._run_attempt() + assert mock_gapic_fn.call_count == 0 - sleep_module = asyncio if self.is_async else time - with mock.patch.object(sleep_module, "sleep"): - pool_size = 7 - client = self._make_client(project="project-id", pool_size=pool_size) - for replace_idx in range(pool_size): - start_pool = [ - channel for channel in client.transport._grpc_channel._pool - ] - grace_period = 9 - with mock.patch.object( - type(client.transport._grpc_channel._pool[-1]), "close" - ) as close: - new_channel = client.transport.create_channel() - client.transport.replace_channel( - replace_idx, grace=grace_period, new_channel=new_channel - ) - close.assert_called_once() - if self.is_async: - close.assert_called_once_with(grace=grace_period) - close.assert_called_once() - assert client.transport._grpc_channel._pool[replace_idx] == new_channel - for i in range(pool_size): - if i != replace_idx: - assert client.transport._grpc_channel._pool[i] == start_pool[i] - else: - assert client.transport._grpc_channel._pool[i] != start_pool[i] - client.close() + def test_run_attempt_partial_success_retryable(self): + """Some entries succeed, but one fails. Should report the proper index, and raise incomplete exception""" + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - def test__start_background_channel_refresh_tasks_exist(self): - client = self._make_client(project="project-id", use_emulator=False) - assert len(client._channel_refresh_tasks) > 0 - with mock.patch.object(asyncio, "create_task") as create_task: - client._start_background_channel_refresh() - create_task.assert_not_called() - client.close() + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() + mutations = [success_mutation, failure_mutation, success_mutation_2] + mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) + instance = self._make_one(mutation_entries=mutations) + instance.is_retryable = lambda x: True + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + with pytest.raises(_MutateRowsIncomplete): + instance._run_attempt() + assert instance.remaining_indices == [1] + assert 0 not in instance.errors + assert len(instance.errors[1]) == 1 + assert instance.errors[1][0].grpc_status_code == 300 + assert 2 not in instance.errors - @pytest.mark.parametrize("pool_size", [1, 3, 7]) - def test__start_background_channel_refresh(self, pool_size): - import concurrent.futures + def test_run_attempt_partial_success_non_retryable(self): + """Some entries succeed, but one fails. Exception marked as non-retryable. 
Do not raise incomplete error""" + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() + mutations = [success_mutation, failure_mutation, success_mutation_2] + mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) + instance = self._make_one(mutation_entries=mutations) + instance.is_retryable = lambda x: False + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + instance._run_attempt() + assert instance.remaining_indices == [] + assert 0 not in instance.errors + assert len(instance.errors[1]) == 1 + assert instance.errors[1][0].grpc_status_code == 300 + assert 2 not in instance.errors - with mock.patch.object( - self._get_target_class(), "_ping_and_warm_instances", mock.Mock() - ) as ping_and_warm: - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - client._start_background_channel_refresh() - assert len(client._channel_refresh_tasks) == pool_size - for task in client._channel_refresh_tasks: - if self.is_async: - assert isinstance(task, asyncio.Task) - else: - assert isinstance(task, concurrent.futures.Future) - time.sleep(0.1) - assert ping_and_warm.call_count == pool_size - for channel in client.transport._grpc_channel._pool: - ping_and_warm.assert_any_call(channel) - client.close() - def test__ping_and_warm_instances(self): - """test ping and warm with mocked asyncio.gather""" - client_mock = mock.Mock() - client_mock._execute_ping_and_warms = ( - lambda *args: self._get_target_class()._execute_ping_and_warms( - client_mock, *args +class TestReadRowsOperation(ABC): + """ + Tests helper functions in the ReadRowsOperation class + in-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt + is tested in test_read_rows_acceptance test_client_read_rows, and conformance tests + """ + + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + + return _ReadRowsOperation + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_ctor(self): + from google.cloud.bigtable.data import ReadRowsQuery + + row_limit = 91 + query = ReadRowsQuery(limit=row_limit) + client = mock.Mock() + client.read_rows = mock.Mock() + client.read_rows.return_value = None + table = mock.Mock() + table._client = client + table.table_name = "test_table" + table.app_profile_id = "test_profile" + expected_operation_timeout = 42 + expected_request_timeout = 44 + time_gen_mock = mock.Mock() + with mock.patch( + "google.cloud.bigtable.data._helpers._attempt_timeout_generator", + time_gen_mock, + ): + instance = self._make_one( + query, + table, + operation_timeout=expected_operation_timeout, + attempt_timeout=expected_request_timeout, ) + assert time_gen_mock.call_count == 1 + time_gen_mock.assert_called_once_with( + expected_request_timeout, expected_operation_timeout ) - gather_tuple = ( - (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") - ) - with mock.patch.object(*gather_tuple, mock.Mock()) as gather: - if self.is_async: - gather.side_effect = lambda *args, **kwargs: [None for _ in args] - else: - gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] - channel = mock.Mock() - client_mock._active_instances = [] - result = self._get_target_class()._ping_and_warm_instances( - client_mock, channel - ) - assert len(result) == 0 - if self.is_async: - assert gather.call_args.kwargs == {"return_exceptions": True} - 
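# Illustrative sketch (editorial aside; hypothetical standalone names, not the
# library's actual implementation): the TestReadRowsOperation ctor test above mocks
# _attempt_timeout_generator. Conceptually, such a helper yields a per-attempt
# timeout that never exceeds the time left in the overall operation budget:
import time
from typing import Generator, Optional

def attempt_timeout_generator(
    attempt_timeout: Optional[float], operation_timeout: float
) -> Generator[float, None, None]:
    """Yield per-attempt timeouts capped by the remaining operation time (sketch)."""
    deadline = time.monotonic() + operation_timeout
    while True:
        remaining = max(0.0, deadline - time.monotonic())
        # an unset attempt timeout falls back to whatever operation time is left
        yield remaining if attempt_timeout is None else min(attempt_timeout, remaining)
# e.g. with attempt_timeout=1.0 and operation_timeout=2.5, successive values are
# roughly 1.0, 1.0, then only the time still remaining before the deadline.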
client_mock._active_instances = [ - (mock.Mock(), mock.Mock(), mock.Mock()) - ] * 4 - gather.reset_mock() - channel.reset_mock() - result = self._get_target_class()._ping_and_warm_instances( - client_mock, channel + assert instance._last_yielded_row_key is None + assert instance._remaining_count == row_limit + assert instance.operation_timeout == expected_operation_timeout + assert client.read_rows.call_count == 0 + assert instance._metadata == [ + ( + "x-goog-request-params", + "table_name=test_table&app_profile_id=test_profile", ) - assert len(result) == 4 - if self.is_async: - gather.assert_called_once() - gather.assert_called_once() - assert len(gather.call_args.args) == 4 - else: - assert gather.call_count == 4 - grpc_call_args = channel.unary_unary().call_args_list - for idx, (_, kwargs) in enumerate(grpc_call_args): - ( - expected_instance, - expected_table, - expected_app_profile, - ) = client_mock._active_instances[idx] - request = kwargs["request"] - assert request["name"] == expected_instance - assert request["app_profile_id"] == expected_app_profile - metadata = kwargs["metadata"] - assert len(metadata) == 1 - assert metadata[0][0] == "x-goog-request-params" - assert ( - metadata[0][1] - == f"name={expected_instance}&app_profile_id={expected_app_profile}" - ) + ] + assert instance.request.table_name == table.table_name + assert instance.request.app_profile_id == table.app_profile_id + assert instance.request.rows_limit == row_limit - def test__ping_and_warm_single_instance(self): - """should be able to call ping and warm with single instance""" - client_mock = mock.Mock() - client_mock._execute_ping_and_warms = ( - lambda *args: self._get_target_class()._execute_ping_and_warms( - client_mock, *args - ) - ) - gather_tuple = ( - (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") - ) - with mock.patch.object(*gather_tuple, mock.Mock()) as gather: - gather.side_effect = lambda *args, **kwargs: [mock.Mock() for _ in args] - if self.is_async: - gather.side_effect = lambda *args, **kwargs: [None for _ in args] - else: - gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] - channel = mock.Mock() - client_mock._active_instances = [mock.Mock()] * 100 - test_key = ("test-instance", "test-table", "test-app-profile") - result = self._get_target_class()._ping_and_warm_instances( - client_mock, channel, test_key - ) - assert len(result) == 1 - grpc_call_args = channel.unary_unary().call_args_list - assert len(grpc_call_args) == 1 - kwargs = grpc_call_args[0][1] - request = kwargs["request"] - assert request["name"] == "test-instance" - assert request["app_profile_id"] == "test-app-profile" - metadata = kwargs["metadata"] - assert len(metadata) == 1 - assert metadata[0][0] == "x-goog-request-params" - assert ( - metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" - ) + @pytest.mark.parametrize( + "in_keys,last_key,expected", + [ + (["b", "c", "d"], "a", ["b", "c", "d"]), + (["a", "b", "c"], "b", ["c"]), + (["a", "b", "c"], "c", []), + (["a", "b", "c"], "d", []), + (["d", "c", "b", "a"], "b", ["d", "c"]), + ], + ) + def test_revise_request_rowset_keys(self, in_keys, last_key, expected): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + in_keys = [key.encode("utf-8") for key in in_keys] + expected = [key.encode("utf-8") for key in expected] + last_key = last_key.encode("utf-8") + sample_range = RowRangePB(start_key_open=last_key) + row_set = 
RowSetPB(row_keys=in_keys, row_ranges=[sample_range]) + revised = self._get_target_class()._revise_request_rowset(row_set, last_key) + assert revised.row_keys == expected + assert revised.row_ranges == [sample_range] @pytest.mark.parametrize( - "refresh_interval, wait_time, expected_sleep", - [(0, 0, 0), (0, 1, 0), (10, 0, 10), (10, 5, 5), (10, 10, 0), (10, 15, 0)], + "in_ranges,last_key,expected", + [ + ( + [{"start_key_open": "b", "end_key_closed": "d"}], + "a", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_closed": "b", "end_key_closed": "d"}], + "a", + [{"start_key_closed": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_open": "a", "end_key_closed": "d"}], + "b", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_closed": "a", "end_key_open": "d"}], + "b", + [{"start_key_open": "b", "end_key_open": "d"}], + ), + ( + [{"start_key_closed": "b", "end_key_closed": "d"}], + "b", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ([{"start_key_closed": "b", "end_key_closed": "d"}], "d", []), + ([{"start_key_closed": "b", "end_key_open": "d"}], "d", []), + ([{"start_key_closed": "b", "end_key_closed": "d"}], "e", []), + ([{"start_key_closed": "b"}], "z", [{"start_key_open": "z"}]), + ([{"start_key_closed": "b"}], "a", [{"start_key_closed": "b"}]), + ( + [{"end_key_closed": "z"}], + "a", + [{"start_key_open": "a", "end_key_closed": "z"}], + ), + ( + [{"end_key_open": "z"}], + "a", + [{"start_key_open": "a", "end_key_open": "z"}], + ), + ], ) - def test__manage_channel_first_sleep( - self, refresh_interval, wait_time, expected_sleep - ): - import threading - import time + def test_revise_request_rowset_ranges(self, in_ranges, last_key, expected): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB - with mock.patch.object(time, "monotonic") as monotonic: - monotonic.return_value = 0 - sleep_tuple = ( - (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + next_key = (last_key + "a").encode("utf-8") + last_key = last_key.encode("utf-8") + in_ranges = [ + RowRangePB(**{k: v.encode("utf-8") for (k, v) in r.items()}) + for r in in_ranges + ] + expected = [ + RowRangePB(**{k: v.encode("utf-8") for (k, v) in r.items()}) + for r in expected + ] + row_set = RowSetPB(row_ranges=in_ranges, row_keys=[next_key]) + revised = self._get_target_class()._revise_request_rowset(row_set, last_key) + assert revised.row_keys == [next_key] + assert revised.row_ranges == expected + + @pytest.mark.parametrize("last_key", ["a", "b", "c"]) + def test_revise_request_full_table(self, last_key): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + last_key = last_key.encode("utf-8") + row_set = RowSetPB() + for selected_set in [row_set, None]: + revised = self._get_target_class()._revise_request_rowset( + selected_set, last_key ) - with mock.patch.object(*sleep_tuple) as sleep: - sleep.side_effect = asyncio.CancelledError - try: - client = self._make_client(project="project-id") - client._channel_init_time = -wait_time - client._manage_channel(0, refresh_interval, refresh_interval) - except asyncio.CancelledError: - pass - sleep.assert_called_once() - call_time = sleep.call_args[0][0] - assert ( - abs(call_time - expected_sleep) < 0.1 - ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" - client.close() + assert revised.row_keys == 
[] + assert len(revised.row_ranges) == 1 + assert revised.row_ranges[0] == RowRangePB(start_key_open=last_key) - def test__manage_channel_ping_and_warm(self): - """_manage channel should call ping and warm internally""" - import time - import threading + def test_revise_to_empty_rowset(self): + """revising to an empty rowset should raise error""" + from google.cloud.bigtable.data.exceptions import _RowSetComplete + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB - client_mock = mock.Mock() - client_mock._is_closed.is_set.return_value = False - client_mock._channel_init_time = time.monotonic() - channel_list = [mock.Mock(), mock.Mock()] - client_mock.transport.channels = channel_list - new_channel = mock.Mock() - client_mock.transport.grpc_channel._create_channel.return_value = new_channel - sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") - with mock.patch.object(*sleep_tuple): - client_mock.transport.replace_channel.side_effect = asyncio.CancelledError - ping_and_warm = client_mock._ping_and_warm_instances = mock.Mock() - try: - channel_idx = 1 - self._get_target_class()._manage_channel(client_mock, channel_idx, 10) - except asyncio.CancelledError: - pass - assert ping_and_warm.call_count == 2 - assert client_mock.transport.replace_channel.call_count == 1 - old_channel = channel_list[channel_idx] - assert old_channel != new_channel - called_with = [call[0][0] for call in ping_and_warm.call_args_list] - assert old_channel in called_with - assert new_channel in called_with - ping_and_warm.reset_mock() - try: - self._get_target_class()._manage_channel(client_mock, 0, 0, 0) - except asyncio.CancelledError: - pass - ping_and_warm.assert_called_once_with(new_channel) + row_keys = [b"a", b"b", b"c"] + row_range = RowRangePB(end_key_open=b"c") + row_set = RowSetPB(row_keys=row_keys, row_ranges=[row_range]) + with pytest.raises(_RowSetComplete): + self._get_target_class()._revise_request_rowset(row_set, b"d") @pytest.mark.parametrize( - "refresh_interval, num_cycles, expected_sleep", - [(None, 1, 60 * 35), (10, 10, 100), (10, 1, 10)], + "start_limit,emit_num,expected_limit", + [ + (10, 0, 10), + (10, 1, 9), + (10, 10, 0), + (None, 10, None), + (None, 0, None), + (4, 2, 2), + ], ) - def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sleep): - import time - import random - import threading + def test_revise_limit(self, start_limit, emit_num, expected_limit): + """ + revise_limit should revise the request's limit field + - if limit is 0 (unlimited), it should never be revised + - if start_limit-emit_num == 0, the request should end early + - if the number emitted exceeds the new limit, an exception should + should be raised (tested in test_revise_limit_over_limit) + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable_v2.types import ReadRowsResponse - channel_idx = 1 - with mock.patch.object(random, "uniform") as uniform: - uniform.side_effect = lambda min_, max_: min_ - with mock.patch.object(time, "time") as time_mock: - time_mock.return_value = 0 - sleep_tuple = ( - (asyncio, "sleep") if self.is_async else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: - sleep.side_effect = [None for i in range(num_cycles - 1)] + [ - asyncio.CancelledError - ] - client = self._make_client(project="project-id") - with mock.patch.object(client.transport, "replace_channel"): - try: - if refresh_interval is not 
None: - client._manage_channel( - channel_idx, refresh_interval, refresh_interval - ) - else: - client._manage_channel(channel_idx) - except asyncio.CancelledError: - pass - assert sleep.call_count == num_cycles - total_sleep = sum([call[0][0] for call in sleep.call_args_list]) - assert ( - abs(total_sleep - expected_sleep) < 0.1 - ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" - client.close() + def awaitable_stream(): + def mock_stream(): + for i in range(emit_num): + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + ] + ) - def test__manage_channel_random(self): - import random - import threading + return mock_stream() - sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") - with mock.patch.object(*sleep_tuple) as sleep: - with mock.patch.object(random, "uniform") as uniform: - uniform.return_value = 0 - try: - uniform.side_effect = asyncio.CancelledError - client = self._make_client(project="project-id", pool_size=1) - except asyncio.CancelledError: - uniform.side_effect = None - uniform.reset_mock() - sleep.reset_mock() - min_val = 200 - max_val = 205 - uniform.side_effect = lambda min_, max_: min_ - sleep.side_effect = [None, None, asyncio.CancelledError] - try: - with mock.patch.object(client.transport, "replace_channel"): - client._manage_channel(0, min_val, max_val) - except asyncio.CancelledError: - pass - assert uniform.call_count == 3 - uniform_args = [call[0] for call in uniform.call_args_list] - for found_min, found_max in uniform_args: - assert found_min == min_val - assert found_max == max_val + query = ReadRowsQuery(limit=start_limit) + table = mock.Mock() + table.table_name = "table_name" + table.app_profile_id = "app_profile_id" + instance = self._make_one(query, table, 10, 10) + assert instance._remaining_count == start_limit + for val in instance.chunk_stream(awaitable_stream()): + pass + assert instance._remaining_count == expected_limit - @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) - def test__manage_channel_refresh(self, num_cycles): - import threading + @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) + def test_revise_limit_over_limit(self, start_limit, emit_num): + """ + Should raise runtime error if we get in state where emit_num > start_num + (unless start_num == 0, which represents unlimited) + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable_v2.types import ReadRowsResponse + from google.cloud.bigtable.data.exceptions import InvalidChunk + + def awaitable_stream(): + def mock_stream(): + for i in range(emit_num): + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + ] + ) + + return mock_stream() + + query = ReadRowsQuery(limit=start_limit) + table = mock.Mock() + table.table_name = "table_name" + table.app_profile_id = "app_profile_id" + instance = self._make_one(query, table, 10, 10) + assert instance._remaining_count == start_limit + with pytest.raises(InvalidChunk) as e: + for val in instance.chunk_stream(awaitable_stream()): + pass + assert "emit count exceeds row limit" in str(e.value) + + def test_close(self): + """ + should be able to close a stream safely with aclose. 
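# Illustrative sketch (editorial aside, hypothetical names; not the library's actual
# code): the revise-limit tests above rely on the operation tracking how many rows it
# may still emit, and treating emission beyond that budget as an invalid stream. A
# simplified standalone version of that bookkeeping:
from typing import Iterable, Iterator, Optional

class RowLimitExceeded(Exception):
    """Local stand-in for the library's InvalidChunk error (sketch only)."""

def limited_rows(rows: Iterable[bytes], limit: Optional[int]) -> Iterator[bytes]:
    """Yield rows while decrementing a remaining-count budget (sketch)."""
    remaining = limit
    for row in rows:
        if remaining is not None:
            remaining -= 1
            if remaining < 0:
                raise RowLimitExceeded("emit count exceeds row limit")
        yield row
# e.g. list(limited_rows([b"a", b"b"], 2)) == [b"a", b"b"], while a third row under
# the same limit of 2 raises RowLimitExceeded.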
+ Closed generators should raise StopIteration on next yield + """ + + def mock_stream(): + while True: + yield 1 - expected_grace = 9 - expected_refresh = 0.5 - channel_idx = 1 - grpc_lib = grpc.aio if self.is_async else grpc - new_channel = grpc_lib.insecure_channel("localhost:8080") with mock.patch.object( - PooledBigtableGrpcTransport, "replace_channel" - ) as replace_channel: - sleep_tuple = ( - (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + self._get_target_class(), "_read_rows_attempt" + ) as mock_attempt: + instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) + wrapped_gen = mock_stream() + mock_attempt.return_value = wrapped_gen + gen = instance.start_operation() + gen.__next__() + gen.close() + with pytest.raises(StopIteration): + gen.__next__() + gen.close() + with pytest.raises(StopIteration): + wrapped_gen.__next__() + + def test_retryable_ignore_repeated_rows(self): + """Duplicate rows should cause an invalid chunk error""" + from google.cloud.bigtable.data.exceptions import InvalidChunk + from google.cloud.bigtable_v2.types import ReadRowsResponse + + row_key = b"duplicate" + + def mock_awaitable_stream(): + def mock_stream(): + while True: + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) + ] + ) + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) + ] + ) + + return mock_stream() + + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + stream = self._get_target_class().chunk_stream( + instance, mock_awaitable_stream() + ) + stream.__next__() + with pytest.raises(InvalidChunk) as exc: + stream.__next__() + assert "row keys should be strictly increasing" in str(exc.value) + + +class Test_FlowControl(ABC): + @staticmethod + def _target_class(): + from google.cloud.bigtable.data._sync.mutations_batcher import _FlowControl + + return _FlowControl + + def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): + return self._target_class()(max_mutation_count, max_mutation_bytes) + + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def test_ctor(self): + max_mutation_count = 9 + max_mutation_bytes = 19 + instance = self._make_one(max_mutation_count, max_mutation_bytes) + assert instance._max_mutation_count == max_mutation_count + assert instance._max_mutation_bytes == max_mutation_bytes + assert instance._in_flight_mutation_count == 0 + assert instance._in_flight_mutation_bytes == 0 + assert isinstance(instance._capacity_condition, threading.Condition) + + def test_ctor_invalid_values(self): + """Test that values are positive, and fit within expected limits""" + with pytest.raises(ValueError) as e: + self._make_one(0, 1) + assert "max_mutation_count must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._make_one(1, 0) + assert "max_mutation_bytes must be greater than 0" in str(e.value) + + @pytest.mark.parametrize( + "max_count,max_size,existing_count,existing_size,new_count,new_size,expected", + [ + (1, 1, 0, 0, 0, 0, True), + (1, 1, 1, 1, 1, 1, False), + (10, 10, 0, 0, 0, 0, True), + (10, 10, 0, 0, 9, 9, True), + (10, 10, 0, 0, 11, 9, True), + (10, 10, 0, 1, 11, 9, True), + (10, 10, 1, 0, 11, 9, False), + (10, 10, 0, 0, 9, 11, True), + (10, 10, 1, 0, 9, 11, True), + (10, 10, 0, 1, 9, 11, False), + (10, 1, 0, 0, 1, 0, True), + (1, 10, 0, 0, 
0, 8, True), + (float("inf"), float("inf"), 0, 0, 10000000000.0, 10000000000.0, True), + (8, 8, 0, 0, 10000000000.0, 10000000000.0, True), + (12, 12, 6, 6, 5, 5, True), + (12, 12, 5, 5, 6, 6, True), + (12, 12, 6, 6, 6, 6, True), + (12, 12, 6, 6, 7, 7, False), + (12, 12, 0, 0, 13, 13, True), + (12, 12, 12, 0, 0, 13, True), + (12, 12, 0, 12, 13, 0, True), + (12, 12, 1, 1, 13, 13, False), + (12, 12, 1, 1, 0, 13, False), + (12, 12, 1, 1, 13, 0, False), + ], + ) + def test__has_capacity( + self, + max_count, + max_size, + existing_count, + existing_size, + new_count, + new_size, + expected, + ): + """_has_capacity should return True if the new mutation will will not exceed the max count or size""" + instance = self._make_one(max_count, max_size) + instance._in_flight_mutation_count = existing_count + instance._in_flight_mutation_bytes = existing_size + assert instance._has_capacity(new_count, new_size) == expected + + @pytest.mark.parametrize( + "existing_count,existing_size,added_count,added_size,new_count,new_size", + [ + (0, 0, 0, 0, 0, 0), + (2, 2, 1, 1, 1, 1), + (2, 0, 1, 0, 1, 0), + (0, 2, 0, 1, 0, 1), + (10, 10, 0, 0, 10, 10), + (10, 10, 5, 5, 5, 5), + (0, 0, 1, 1, -1, -1), + ], + ) + def test_remove_from_flow_value_update( + self, + existing_count, + existing_size, + added_count, + added_size, + new_count, + new_size, + ): + """completed mutations should lower the inflight values""" + instance = self._make_one() + instance._in_flight_mutation_count = existing_count + instance._in_flight_mutation_bytes = existing_size + mutation = self._make_mutation(added_count, added_size) + instance.remove_from_flow(mutation) + assert instance._in_flight_mutation_count == new_count + assert instance._in_flight_mutation_bytes == new_size + + def test__remove_from_flow_unlock(self): + """capacity condition should notify after mutation is complete""" + import inspect + + instance = self._make_one(10, 10) + instance._in_flight_mutation_count = 10 + instance._in_flight_mutation_bytes = 10 + + def task_routine(): + with instance._capacity_condition: + instance._capacity_condition.wait_for( + lambda: instance._has_capacity(1, 1) + ) + + if inspect.iscoroutinefunction(task_routine): + task = threading.Thread(task_routine()) + task_alive = lambda: not task.done() + else: + import threading + + thread = threading.Thread(target=task_routine) + thread.start() + task_alive = thread.is_alive + time.sleep(0.05) + assert task_alive() is True + mutation = self._make_mutation(count=0, size=5) + instance.remove_from_flow([mutation]) + time.sleep(0.05) + assert instance._in_flight_mutation_count == 10 + assert instance._in_flight_mutation_bytes == 5 + assert task_alive() is True + instance._in_flight_mutation_bytes = 10 + mutation = self._make_mutation(count=5, size=0) + instance.remove_from_flow([mutation]) + time.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 10 + assert task_alive() is True + instance._in_flight_mutation_count = 10 + mutation = self._make_mutation(count=5, size=5) + instance.remove_from_flow([mutation]) + time.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 5 + assert task_alive() is False + + @pytest.mark.parametrize( + "mutations,count_cap,size_cap,expected_results", + [ + ([(5, 5), (1, 1), (1, 1)], 10, 10, [[(5, 5), (1, 1), (1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 1, 1, [[(1, 1)], [(1, 1)], [(1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 2, 10, [[(1, 1), (1, 1)], [(1, 1)]]), + ([(1, 1), 
(1, 1), (1, 1)], 10, 2, [[(1, 1), (1, 1)], [(1, 1)]]), + ( + [(1, 1), (5, 5), (4, 1), (1, 4), (1, 1)], + 5, + 5, + [[(1, 1)], [(5, 5)], [(4, 1), (1, 4)], [(1, 1)]], + ), + ], + ) + def test_add_to_flow(self, mutations, count_cap, size_cap, expected_results): + """Test batching with various flow control settings""" + mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] + instance = self._make_one(count_cap, size_cap) + i = 0 + for batch in instance.add_to_flow(mutation_objs): + expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + assert len(batch[j].mutations) == expected_batch[j][0] + assert batch[j].size() == expected_batch[j][1] + instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + @pytest.mark.parametrize( + "mutations,max_limit,expected_results", + [ + ([(1, 1)] * 11, 10, [[(1, 1)] * 10, [(1, 1)]]), + ([(1, 1)] * 10, 1, [[(1, 1)] for _ in range(10)]), + ([(1, 1)] * 10, 2, [[(1, 1), (1, 1)] for _ in range(5)]), + ], + ) + def test_add_to_flow_max_mutation_limits( + self, mutations, max_limit, expected_results + ): + """ + Test flow control running up against the max API limit + Should submit request early, even if the flow control has room for more + """ + async_patch = mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + max_limit, + ) + sync_patch = mock.patch( + "google.cloud.bigtable.data._sync._autogen._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + max_limit, + ) + with async_patch, sync_patch: + mutation_objs = [ + self._make_mutation(count=m[0], size=m[1]) for m in mutations + ] + instance = self._make_one(float("inf"), float("inf")) + i = 0 + for batch in instance.add_to_flow(mutation_objs): + expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + assert len(batch[j].mutations) == expected_batch[j][0] + assert batch[j].size() == expected_batch[j][1] + instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + def test_add_to_flow_oversize(self): + """mutations over the flow control limits should still be accepted""" + instance = self._make_one(2, 3) + large_size_mutation = self._make_mutation(count=1, size=10) + large_count_mutation = self._make_mutation(count=10, size=1) + results = [out for out in instance.add_to_flow([large_size_mutation])] + assert len(results) == 1 + instance.remove_from_flow(results[0]) + count_results = [out for out in instance.add_to_flow(large_count_mutation)] + assert len(count_results) == 1 + + +class TestMutationsBatcher(ABC): + def _get_target_class(self): + from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher + + return MutationsBatcher + + @staticmethod + def is_async(): + return False + + def _make_one(self, table=None, **kwargs): + from google.api_core.exceptions import DeadlineExceeded + from google.api_core.exceptions import ServiceUnavailable + + if table is None: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 10 + table.default_mutate_rows_retryable_errors = ( + DeadlineExceeded, + ServiceUnavailable, ) - with mock.patch.object(*sleep_tuple) as sleep: - sleep.side_effect = [None for i in range(num_cycles)] + [ - asyncio.CancelledError - ] - with mock.patch.object( - grpc_helpers, "create_channel" - ) as create_channel: - create_channel.return_value = new_channel - with mock.patch.object( 
- self._get_target_class(), "_start_background_channel_refresh" - ): - client = self._make_client( - project="project-id", use_emulator=False - ) - create_channel.reset_mock() - try: - client._manage_channel( - channel_idx, - refresh_interval_min=expected_refresh, - refresh_interval_max=expected_refresh, - grace_period=expected_grace, - ) - except asyncio.CancelledError: - pass - assert sleep.call_count == num_cycles + 1 - assert create_channel.call_count == num_cycles - assert replace_channel.call_count == num_cycles - for call in replace_channel.call_args_list: - (args, kwargs) = call - assert args[0] == channel_idx - assert kwargs["grace"] == expected_grace - assert kwargs["new_channel"] == new_channel - client.close() + return self._get_target_class()(table, **kwargs) - def test__register_instance(self): - """test instance registration""" - client_mock = mock.Mock() - client_mock._gapic_client.instance_path.side_effect = lambda a, b: f"prefix/{b}" - active_instances = set() - instance_owners = {} - client_mock._active_instances = active_instances - client_mock._instance_owners = instance_owners - client_mock._channel_refresh_tasks = [] - client_mock._start_background_channel_refresh.side_effect = ( - lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def test_ctor_defaults(self): + with mock.patch.object( + self._get_target_class(), + "_timer_routine", + return_value=concurrent.futures.Future(), + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = [Exception] + with self._make_one(table) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._max_mutation_count == 100000 + assert instance._flow_control._max_mutation_bytes == 104857600 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert ( + instance._operation_timeout + == table.default_mutate_rows_operation_timeout + ) + assert ( + instance._attempt_timeout + == table.default_mutate_rows_attempt_timeout + ) + assert ( + instance._retryable_errors + == table.default_mutate_rows_retryable_errors + ) + time.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == 5 + assert isinstance(instance._flush_timer, concurrent.futures.Future) + + def test_ctor_explicit(self): + """Test with explicit parameters""" + with mock.patch.object( + self._get_target_class(), + "_timer_routine", + return_value=concurrent.futures.Future(), + ) as flush_timer_mock: + table = mock.Mock() + flush_interval = 20 + flush_limit_count = 17 + flush_limit_bytes = 19 + flow_control_max_mutation_count = 1001 + flow_control_max_bytes = 12 + operation_timeout = 11 + attempt_timeout = 2 + retryable_errors = [Exception] + with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + 
flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=operation_timeout, + batch_attempt_timeout=attempt_timeout, + batch_retryable_errors=retryable_errors, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert ( + instance._flow_control._max_mutation_count + == flow_control_max_mutation_count + ) + assert ( + instance._flow_control._max_mutation_bytes == flow_control_max_bytes + ) + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert instance._operation_timeout == operation_timeout + assert instance._attempt_timeout == attempt_timeout + assert instance._retryable_errors == retryable_errors + time.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == flush_interval + assert isinstance(instance._flush_timer, concurrent.futures.Future) + + def test_ctor_no_flush_limits(self): + """Test with None for flush limits""" + with mock.patch.object( + self._get_target_class(), + "_timer_routine", + return_value=concurrent.futures.Future(), + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = () + flush_interval = None + flush_limit_count = None + flush_limit_bytes = None + with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._staged_entries == [] + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + time.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] is None + assert isinstance(instance._flush_timer, concurrent.futures.Future) + + def test_ctor_invalid_values(self): + """Test that timeout values are positive, and fit within expected limits""" + with pytest.raises(ValueError) as e: + self._make_one(batch_operation_timeout=-1) + assert "operation_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._make_one(batch_attempt_timeout=-1) + assert "attempt_timeout must be greater than 0" in str(e.value) + + def test_default_argument_consistency(self): + """ + We supply default arguments in MutationsBatcherAsync.__init__, and in + table.mutations_batcher. 
Make sure any changes to defaults are applied to + both places + """ + import inspect + + get_batcher_signature = dict( + inspect.signature(Table.mutations_batcher).parameters ) - mock_channels = [mock.Mock() for i in range(5)] - client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = mock.Mock() - table_mock = mock.Mock() - self._get_target_class()._register_instance( - client_mock, "instance-1", table_mock + get_batcher_signature.pop("self") + batcher_init_signature = dict( + inspect.signature(self._get_target_class()).parameters ) - assert client_mock._start_background_channel_refresh.call_count == 1 - expected_key = ( - "prefix/instance-1", - table_mock.table_name, - table_mock.app_profile_id, + batcher_init_signature.pop("table") + assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) + assert len(get_batcher_signature) == 8 + assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) + for arg_name in get_batcher_signature.keys(): + assert ( + get_batcher_signature[arg_name].default + == batcher_init_signature[arg_name].default + ) + + @pytest.mark.parametrize("input_val", [None, 0, -1]) + def test__start_flush_timer_w_empty_input(self, input_val): + """Empty/invalid timer should return immediately""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + with self._make_one() as instance: + if self.is_async(): + (sleep_obj, sleep_method) = (asyncio, "wait_for") + else: + (sleep_obj, sleep_method) = (instance._closed, "wait") + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + result = instance._timer_routine(input_val) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + assert result is None + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test__start_flush_timer_call_when_closed(self): + """closed batcher's timer should return immediately""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + with self._make_one() as instance: + instance.close() + flush_mock.reset_mock() + if self.is_async(): + (sleep_obj, sleep_method) = (asyncio, "wait_for") + else: + (sleep_obj, sleep_method) = (instance._closed, "wait") + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + instance._timer_routine(10) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + + @pytest.mark.parametrize("num_staged", [0, 1, 10]) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test__flush_timer(self, num_staged): + """Timer should continue to call _schedule_flush in a loop""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + expected_sleep = 12 + with self._make_one(flush_interval=expected_sleep) as instance: + loop_num = 3 + instance._staged_entries = [mock.Mock()] * num_staged + if self.is_async(): + (sleep_obj, sleep_method) = (asyncio, "wait_for") + else: + (sleep_obj, sleep_method) = (instance._closed, "wait") + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] + with pytest.raises(TabError): + self._get_target_class()._timer_routine( + instance, expected_sleep + ) + instance._flush_timer = concurrent.futures.Future() + assert sleep_mock.call_count == loop_num + 1 + sleep_kwargs = sleep_mock.call_args[1] + assert sleep_kwargs["timeout"] == expected_sleep + assert flush_mock.call_count == (0 if num_staged == 0 else loop_num) + + def 
test__flush_timer_close(self): + """Timer should continue terminate after close""" + with mock.patch.object(self._get_target_class(), "_schedule_flush"): + with self._make_one() as instance: + with mock.patch("asyncio.sleep"): + time.sleep(0.5) + assert instance._flush_timer.done() is False + instance.close() + time.sleep(0.1) + assert instance._flush_timer.done() is True + + def test_append_closed(self): + """Should raise exception""" + instance = self._make_one() + instance.close() + with pytest.raises(RuntimeError): + instance.append(mock.Mock()) + + def test_append_wrong_mutation(self): + """ + Mutation objects should raise an exception. + Only support RowMutationEntry + """ + from google.cloud.bigtable.data.mutations import DeleteAllFromRow + + with self._make_one() as instance: + expected_error = "invalid mutation type: DeleteAllFromRow. Only RowMutationEntry objects are supported by batcher" + with pytest.raises(ValueError) as e: + instance.append(DeleteAllFromRow()) + assert str(e.value) == expected_error + + def test_append_outside_flow_limits(self): + """entries larger than mutation limits are still processed""" + with self._make_one( + flow_control_max_mutation_count=1, flow_control_max_bytes=1 + ) as instance: + oversized_entry = self._make_mutation(count=0, size=2) + instance.append(oversized_entry) + assert instance._staged_entries == [oversized_entry] + assert instance._staged_count == 0 + assert instance._staged_bytes == 2 + instance._staged_entries = [] + with self._make_one( + flow_control_max_mutation_count=1, flow_control_max_bytes=1 + ) as instance: + overcount_entry = self._make_mutation(count=2, size=0) + instance.append(overcount_entry) + assert instance._staged_entries == [overcount_entry] + assert instance._staged_count == 2 + assert instance._staged_bytes == 0 + instance._staged_entries = [] + + def test_append_flush_runs_after_limit_hit(self): + """ + If the user appends a bunch of entries above the flush limits back-to-back, + it should still flush in a single task + """ + with mock.patch.object( + self._get_target_class(), "_execute_mutate_rows" + ) as op_mock: + with self._make_one(flush_limit_bytes=100) as instance: + + def mock_call(*args, **kwargs): + return [] + + op_mock.side_effect = mock_call + instance.append(self._make_mutation(size=99)) + num_entries = 10 + for _ in range(num_entries): + instance.append(self._make_mutation(size=1)) + instance._wait_for_batch_results(*instance._flush_jobs) + assert op_mock.call_count == 1 + sent_batch = op_mock.call_args[0][0] + assert len(sent_batch) == 2 + assert len(instance._staged_entries) == num_entries - 1 + + @pytest.mark.parametrize( + "flush_count,flush_bytes,mutation_count,mutation_bytes,expect_flush", + [ + (10, 10, 1, 1, False), + (10, 10, 9, 9, False), + (10, 10, 10, 1, True), + (10, 10, 1, 10, True), + (10, 10, 10, 10, True), + (1, 1, 10, 10, True), + (1, 1, 0, 0, False), + ], + ) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_append( + self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush + ): + """test appending different mutations, and checking if it causes a flush""" + with self._make_one( + flush_limit_mutation_count=flush_count, flush_limit_bytes=flush_bytes + ) as instance: + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert instance._staged_entries == [] + mutation = self._make_mutation(count=mutation_count, size=mutation_bytes) + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + 
instance.append(mutation) + assert flush_mock.call_count == bool(expect_flush) + assert instance._staged_count == mutation_count + assert instance._staged_bytes == mutation_bytes + assert instance._staged_entries == [mutation] + instance._staged_entries = [] + + def test_append_multiple_sequentially(self): + """Append multiple mutations""" + with self._make_one( + flush_limit_mutation_count=8, flush_limit_bytes=8 + ) as instance: + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert instance._staged_entries == [] + mutation = self._make_mutation(count=2, size=3) + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + instance.append(mutation) + assert flush_mock.call_count == 0 + assert instance._staged_count == 2 + assert instance._staged_bytes == 3 + assert len(instance._staged_entries) == 1 + instance.append(mutation) + assert flush_mock.call_count == 0 + assert instance._staged_count == 4 + assert instance._staged_bytes == 6 + assert len(instance._staged_entries) == 2 + instance.append(mutation) + assert flush_mock.call_count == 1 + assert instance._staged_count == 6 + assert instance._staged_bytes == 9 + assert len(instance._staged_entries) == 3 + instance._staged_entries = [] + + def test_flush_flow_control_concurrent_requests(self): + """requests should happen in parallel if flow control breaks up single flush into batches""" + import time + + num_calls = 10 + fake_mutations = [self._make_mutation(count=1) for _ in range(num_calls)] + with self._make_one(flow_control_max_mutation_count=1) as instance: + with mock.patch.object( + instance, "_execute_mutate_rows", mock.Mock() + ) as op_mock: + + def mock_call(*args, **kwargs): + time.sleep(0.1) + return [] + + op_mock.side_effect = mock_call + start_time = time.monotonic() + instance._staged_entries = fake_mutations + instance._schedule_flush() + time.sleep(0.01) + for i in range(num_calls): + instance._flow_control.remove_from_flow( + [self._make_mutation(count=1)] + ) + time.sleep(0.01) + instance._wait_for_batch_results(*instance._flush_jobs) + duration = time.monotonic() - start_time + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert duration < 0.5 + assert op_mock.call_count == num_calls + + def test_schedule_flush_no_mutations(self): + """schedule flush should return None if no staged mutations""" + with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal") as flush_mock: + for i in range(3): + assert instance._schedule_flush() is None + assert flush_mock.call_count == 0 + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_schedule_flush_with_mutations(self): + """if new mutations exist, should add a new flush task to _flush_jobs""" + with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal") as flush_mock: + if not self.is_async(): + flush_mock.side_effect = lambda x: time.sleep(0.1) + for i in range(1, 4): + mutation = mock.Mock() + instance._staged_entries = [mutation] + instance._schedule_flush() + assert instance._staged_entries == [] + time.sleep(0) + assert instance._staged_entries == [] + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert flush_mock.call_count == 1 + flush_mock.reset_mock() + + def test__flush_internal(self): + """ + _flush_internal should: + - await previous flush call + - delegate batching to _flow_control + - call _execute_mutate_rows on each batch + - update self.exceptions and 
self._entries_processed_since_last_raise + """ + num_entries = 10 + with self._make_one() as instance: + with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: + with mock.patch.object( + instance._flow_control, "add_to_flow" + ) as flow_mock: + + def gen(x): + yield x + + flow_mock.side_effect = lambda x: gen(x) + mutations = [self._make_mutation(count=1, size=1)] * num_entries + instance._flush_internal(mutations) + assert instance._entries_processed_since_last_raise == num_entries + assert execute_mock.call_count == 1 + assert flow_mock.call_count == 1 + instance._oldest_exceptions.clear() + instance._newest_exceptions.clear() + + def test_flush_clears_job_list(self): + """ + a job should be added to _flush_jobs when _schedule_flush is called, + and removed when it completes + """ + with self._make_one() as instance: + with mock.patch.object( + instance, "_flush_internal", mock.Mock() + ) as flush_mock: + if not self.is_async(): + flush_mock.side_effect = lambda x: time.sleep(0.1) + mutations = [self._make_mutation(count=1, size=1)] + instance._staged_entries = mutations + assert instance._flush_jobs == set() + new_job = instance._schedule_flush() + assert instance._flush_jobs == {new_job} + if self.is_async(): + new_job + else: + new_job.result() + assert instance._flush_jobs == set() + + @pytest.mark.parametrize( + "num_starting,num_new_errors,expected_total_errors", + [ + (0, 0, 0), + (0, 1, 1), + (0, 2, 2), + (1, 0, 1), + (1, 1, 2), + (10, 2, 12), + (10, 20, 20), + ], + ) + def test__flush_internal_with_errors( + self, num_starting, num_new_errors, expected_total_errors + ): + """errors returned from _execute_mutate_rows should be added to internal exceptions""" + from google.cloud.bigtable.data import exceptions + + num_entries = 10 + expected_errors = [ + exceptions.FailedMutationEntryError(mock.Mock(), mock.Mock(), ValueError()) + ] * num_new_errors + with self._make_one() as instance: + instance._oldest_exceptions = [mock.Mock()] * num_starting + with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: + execute_mock.return_value = expected_errors + with mock.patch.object( + instance._flow_control, "add_to_flow" + ) as flow_mock: + + def gen(x): + yield x + + flow_mock.side_effect = lambda x: gen(x) + mutations = [self._make_mutation(count=1, size=1)] * num_entries + instance._flush_internal(mutations) + assert instance._entries_processed_since_last_raise == num_entries + assert execute_mock.call_count == 1 + assert flow_mock.call_count == 1 + found_exceptions = instance._oldest_exceptions + list( + instance._newest_exceptions + ) + assert len(found_exceptions) == expected_total_errors + for i in range(num_starting, expected_total_errors): + assert found_exceptions[i] == expected_errors[i - num_starting] + assert found_exceptions[i].index is None + instance._oldest_exceptions.clear() + instance._newest_exceptions.clear() + + def _mock_gapic_return(self, num=5): + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc import status_pb2 + + def gen(num): + for i in range(num): + entry = MutateRowsResponse.Entry( + index=i, status=status_pb2.Status(code=0) + ) + yield MutateRowsResponse(entries=[entry]) + + return gen(num) + + def test_timer_flush_end_to_end(self): + """Flush should automatically trigger after flush_interval""" + num_nutations = 10 + mutations = [self._make_mutation(count=2, size=2)] * num_nutations + with self._make_one(flush_interval=0.05) as instance: + instance._table.default_operation_timeout = 10 + 
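# Illustrative sketch (editorial aside, hypothetical names; not the batcher's actual
# implementation): the flush-timer tests above exercise a sync timer loop that waits
# on the batcher's "_closed" threading.Event with a timeout instead of using
# asyncio.wait_for, flushing staged entries each time the wait times out:
import threading

def timer_routine(closed: threading.Event, interval, has_staged, schedule_flush):
    """Flush staged entries every `interval` seconds until closed (sketch)."""
    if not interval or interval <= 0:
        return  # no periodic flushing configured
    while not closed.is_set():
        # Event.wait returns True only if the event was set (i.e. the batcher closed)
        if closed.wait(timeout=interval):
            break
        if has_staged():
            schedule_flush()
# A worker thread could run this, e.g.:
# threading.Thread(target=timer_routine,
#                  args=(closed_event, 5, lambda: bool(staged), flush),
#                  daemon=True).start()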
instance._table.default_attempt_timeout = 9 + with mock.patch.object( + instance._table.client._gapic_client, "mutate_rows" + ) as gapic_mock: + gapic_mock.side_effect = ( + lambda *args, **kwargs: self._mock_gapic_return(num_nutations) + ) + for m in mutations: + instance.append(m) + assert instance._entries_processed_since_last_raise == 0 + time.sleep(0.1) + assert instance._entries_processed_since_last_raise == num_nutations + + def test__execute_mutate_rows(self): + if self.is_async(): + mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" + else: + mutate_path = "_sync._mutate_rows._MutateRowsOperation" + with mock.patch(f"google.cloud.bigtable.data.{mutate_path}") as mutate_rows: + mutate_rows.return_value = mock.Mock() + start_operation = mutate_rows().start + table = mock.Mock() + table.table_name = "test-table" + table.app_profile_id = "test-app-profile" + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + with self._make_one(table) as instance: + batch = [self._make_mutation()] + result = instance._execute_mutate_rows(batch) + assert start_operation.call_count == 1 + (args, kwargs) = mutate_rows.call_args + assert args[0] == table.client._gapic_client + assert args[1] == table + assert args[2] == batch + kwargs["operation_timeout"] == 17 + kwargs["attempt_timeout"] == 13 + assert result == [] + + def test__execute_mutate_rows_returns_errors(self): + """Errors from operation should be retruned as list""" + from google.cloud.bigtable.data.exceptions import ( + MutationsExceptionGroup, + FailedMutationEntryError, ) - assert len(active_instances) == 1 - assert expected_key == tuple(list(active_instances)[0]) - assert len(instance_owners) == 1 - assert expected_key == tuple(list(instance_owners)[0]) - assert client_mock._channel_refresh_tasks - table_mock2 = mock.Mock() - self._get_target_class()._register_instance( - client_mock, "instance-2", table_mock2 + + if self.is_async(): + mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" + else: + mutate_path = "_sync._mutate_rows._MutateRowsOperation" + with mock.patch( + f"google.cloud.bigtable.data.{mutate_path}.start" + ) as mutate_rows: + err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) + err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) + mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + with self._make_one(table) as instance: + batch = [self._make_mutation()] + result = instance._execute_mutate_rows(batch) + assert len(result) == 2 + assert result[0] == err1 + assert result[1] == err2 + assert result[0].index is None + assert result[1].index is None + + def test__raise_exceptions(self): + """Raise exceptions and reset error state""" + from google.cloud.bigtable.data import exceptions + + expected_total = 1201 + expected_exceptions = [RuntimeError("mock")] * 3 + with self._make_one() as instance: + instance._oldest_exceptions = expected_exceptions + instance._entries_processed_since_last_raise = expected_total + try: + instance._raise_exceptions() + except exceptions.MutationsExceptionGroup as exc: + assert list(exc.exceptions) == expected_exceptions + assert str(expected_total) in str(exc) + assert instance._entries_processed_since_last_raise == 0 + 
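# Illustrative sketch (editorial aside, hypothetical names; not the library's actual
# classes): test__raise_exceptions above checks a "collect, then raise, then reset"
# pattern, where accumulated failures are swapped out and surfaced as one grouped
# error that also reports how many entries were processed:
class MutationsGroupError(Exception):
    """Stand-in for a MutationsExceptionGroup-style grouped error (sketch)."""

    def __init__(self, exceptions, total_entries):
        super().__init__(f"{len(exceptions)} failed entries out of {total_entries} attempted")
        self.exceptions = exceptions

class ErrorTracker:
    def __init__(self):
        self.errors = []
        self.entries_processed = 0

    def raise_exceptions(self):
        """Raise accumulated errors (if any) and reset the counters (sketch)."""
        if not self.errors:
            return
        errors, self.errors = self.errors, []
        total, self.entries_processed = self.entries_processed, 0
        raise MutationsGroupError(errors, total)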
(instance._oldest_exceptions, instance._newest_exceptions) = ([], []) + instance._raise_exceptions() + + def test___aenter__(self): + """Should return self""" + with self._make_one() as instance: + assert instance.__enter__() == instance + + def test___aexit__(self): + """aexit should call close""" + with self._make_one() as instance: + with mock.patch.object(instance, "close") as close_mock: + instance.__exit__(None, None, None) + assert close_mock.call_count == 1 + + def test_close(self): + """Should clean up all resources""" + with self._make_one() as instance: + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + with mock.patch.object(instance, "_raise_exceptions") as raise_mock: + instance.close() + assert instance.closed is True + assert instance._flush_timer.done() is True + assert instance._flush_jobs == set() + assert flush_mock.call_count == 1 + assert raise_mock.call_count == 1 + + def test_close_w_exceptions(self): + """Raise exceptions on close""" + from google.cloud.bigtable.data import exceptions + + expected_total = 10 + expected_exceptions = [RuntimeError("mock")] + with self._make_one() as instance: + instance._oldest_exceptions = expected_exceptions + instance._entries_processed_since_last_raise = expected_total + try: + instance.close() + except exceptions.MutationsExceptionGroup as exc: + assert list(exc.exceptions) == expected_exceptions + assert str(expected_total) in str(exc) + assert instance._entries_processed_since_last_raise == 0 + (instance._oldest_exceptions, instance._newest_exceptions) = ([], []) + + def test__on_exit(self, recwarn): + """Should raise warnings if unflushed mutations exist""" + with self._make_one() as instance: + instance._on_exit() + assert len(recwarn) == 0 + num_left = 4 + instance._staged_entries = [mock.Mock()] * num_left + with pytest.warns(UserWarning) as w: + instance._on_exit() + assert len(w) == 1 + assert "unflushed mutations" in str(w[0].message).lower() + assert str(num_left) in str(w[0].message) + instance._closed.set() + instance._on_exit() + assert len(recwarn) == 0 + instance._staged_entries = [] + + def test_atexit_registration(self): + """Should run _on_exit on program termination""" + import atexit + + with mock.patch.object(atexit, "register") as register_mock: + assert register_mock.call_count == 0 + with self._make_one(): + assert register_mock.call_count == 1 + + def test_timeout_args_passed(self): + """ + batch_operation_timeout and batch_attempt_timeout should be used + in api calls + """ + if self.is_async(): + mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" + else: + mutate_path = "_sync._mutate_rows._MutateRowsOperation" + with mock.patch( + f"google.cloud.bigtable.data.{mutate_path}", return_value=mock.Mock() + ) as mutate_rows: + expected_operation_timeout = 17 + expected_attempt_timeout = 13 + with self._make_one( + batch_operation_timeout=expected_operation_timeout, + batch_attempt_timeout=expected_attempt_timeout, + ) as instance: + assert instance._operation_timeout == expected_operation_timeout + assert instance._attempt_timeout == expected_attempt_timeout + instance._execute_mutate_rows([self._make_mutation()]) + assert mutate_rows.call_count == 1 + kwargs = mutate_rows.call_args[1] + assert kwargs["operation_timeout"] == expected_operation_timeout + assert kwargs["attempt_timeout"] == expected_attempt_timeout + + @pytest.mark.parametrize( + "limit,in_e,start_e,end_e", + [ + (10, 0, (10, 0), (10, 0)), + (1, 10, (0, 0), (1, 1)), + (10, 1, (0, 0), (1, 0)), + (10, 10, (0, 
0), (10, 0)), + (10, 11, (0, 0), (10, 1)), + (3, 20, (0, 0), (3, 3)), + (10, 20, (0, 0), (10, 10)), + (10, 21, (0, 0), (10, 10)), + (2, 1, (2, 0), (2, 1)), + (2, 1, (1, 0), (2, 0)), + (2, 2, (1, 0), (2, 1)), + (3, 1, (3, 1), (3, 2)), + (3, 3, (3, 1), (3, 3)), + (1000, 5, (999, 0), (1000, 4)), + (1000, 5, (0, 0), (5, 0)), + (1000, 5, (1000, 0), (1000, 5)), + ], + ) + def test__add_exceptions(self, limit, in_e, start_e, end_e): + """ + Test that the _add_exceptions function properly updates the + _oldest_exceptions and _newest_exceptions lists + Args: + - limit: the _exception_list_limit representing the max size of either list + - in_e: size of list of exceptions to send to _add_exceptions + - start_e: a tuple of ints representing the initial sizes of _oldest_exceptions and _newest_exceptions + - end_e: a tuple of ints representing the expected sizes of _oldest_exceptions and _newest_exceptions + """ + from collections import deque + + input_list = [RuntimeError(f"mock {i}") for i in range(in_e)] + mock_batcher = mock.Mock() + mock_batcher._oldest_exceptions = [ + RuntimeError(f"starting mock {i}") for i in range(start_e[0]) + ] + mock_batcher._newest_exceptions = deque( + [RuntimeError(f"starting mock {i}") for i in range(start_e[1])], + maxlen=limit, + ) + mock_batcher._exception_list_limit = limit + mock_batcher._exceptions_since_last_raise = 0 + self._get_target_class()._add_exceptions(mock_batcher, input_list) + assert len(mock_batcher._oldest_exceptions) == end_e[0] + assert len(mock_batcher._newest_exceptions) == end_e[1] + assert mock_batcher._exceptions_since_last_raise == in_e + oldest_list_diff = end_e[0] - start_e[0] + newest_list_diff = min(max(in_e - oldest_list_diff, 0), limit) + for i in range(oldest_list_diff): + assert mock_batcher._oldest_exceptions[i + start_e[0]] == input_list[i] + for i in range(1, newest_list_diff + 1): + assert mock_batcher._newest_exceptions[-i] == input_list[-i] + + @pytest.mark.parametrize( + "input_retryables,expected_retryables", + [ + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), + ], + ) + def test_customizable_retryable_errors(self, input_retryables, expected_retryables): + """ + Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer. 
+ """ + retryn_fn = ( + "retry_target_async" + if "Async" in self._get_target_class().__name__ + else "retry_target" + ) + with mock.patch.object( + google.api_core.retry, "if_exception_type" + ) as predicate_builder_mock: + with mock.patch.object(google.api_core.retry, retryn_fn) as retry_fn_mock: + table = None + with mock.patch("asyncio.create_task"): + table = Table(mock.Mock(), "instance", "table") + with self._make_one( + table, batch_retryable_errors=input_retryables + ) as instance: + assert instance._retryable_errors == expected_retryables + expected_predicate = lambda a: a in expected_retryables + predicate_builder_mock.return_value = expected_predicate + retry_fn_mock.side_effect = RuntimeError("stop early") + mutation = self._make_mutation(count=1, size=1) + instance._execute_mutate_rows([mutation]) + predicate_builder_mock.assert_called_once_with( + *expected_retryables, _MutateRowsIncomplete + ) + retry_call_args = retry_fn_mock.call_args_list[0].args + assert retry_call_args[1] is expected_predicate + + +class TestBigtableDataClient(ABC): + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.data._sync.client import BigtableDataClient + + return BigtableDataClient + + @classmethod + def _make_client(cls, *args, use_emulator=True, **kwargs): + import os + + env_mask = {} + if use_emulator: + env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" + import warnings + + warnings.filterwarnings("ignore", category=RuntimeWarning) + else: + kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) + kwargs["project"] = kwargs.get("project", "project-id") + with mock.patch.dict(os.environ, env_mask): + return cls._get_target_class()(*args, **kwargs) + + @property + def is_async(self): + return False + + def test_ctor(self): + expected_project = "project-id" + expected_pool_size = 11 + expected_credentials = AnonymousCredentials() + client = self._make_client( + project="project-id", + pool_size=expected_pool_size, + credentials=expected_credentials, + use_emulator=False, + ) + time.sleep(0) + assert client.project == expected_project + assert len(client.transport._grpc_channel._pool) == expected_pool_size + assert not client._active_instances + assert len(client._channel_refresh_tasks) == expected_pool_size + assert client.transport._credentials == expected_credentials + client.close() + + def test_ctor_super_inits(self): + from google.cloud.client import ClientWithProject + from google.api_core import client_options as client_options_lib + from google.cloud.bigtable import __version__ as bigtable_version + + project = "project-id" + pool_size = 11 + credentials = AnonymousCredentials() + client_options = {"api_endpoint": "foo.bar:1234"} + options_parsed = client_options_lib.from_dict(client_options) + asyncio_portion = "-async" if self.is_async else "" + transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" + with mock.patch.object(BigtableClient, "__init__") as bigtable_client_init: + bigtable_client_init.return_value = None + with mock.patch.object( + ClientWithProject, "__init__" + ) as client_project_init: + client_project_init.return_value = None + try: + self._make_client( + project=project, + pool_size=pool_size, + credentials=credentials, + client_options=options_parsed, + use_emulator=False, + ) + except AttributeError: + pass + assert bigtable_client_init.call_count == 1 + kwargs = bigtable_client_init.call_args[1] + assert kwargs["transport"] == transport_str + assert kwargs["credentials"] == credentials + assert 
kwargs["client_options"] == options_parsed + assert client_project_init.call_count == 1 + kwargs = client_project_init.call_args[1] + assert kwargs["project"] == project + assert kwargs["credentials"] == credentials + assert kwargs["client_options"] == options_parsed + + def test_ctor_dict_options(self): + from google.api_core.client_options import ClientOptions + + client_options = {"api_endpoint": "foo.bar:1234"} + with mock.patch.object(BigtableClient, "__init__") as bigtable_client_init: + try: + self._make_client(client_options=client_options) + except TypeError: + pass + bigtable_client_init.assert_called_once() + kwargs = bigtable_client_init.call_args[1] + called_options = kwargs["client_options"] + assert called_options.api_endpoint == "foo.bar:1234" + assert isinstance(called_options, ClientOptions) + with mock.patch.object( + self._get_target_class(), "_start_background_channel_refresh" + ) as start_background_refresh: + client = self._make_client( + client_options=client_options, use_emulator=False + ) + start_background_refresh.assert_called_once() + client.close() + + def test_veneer_grpc_headers(self): + client_component = "data-async" if self.is_async else "data" + VENEER_HEADER_REGEX = re.compile( + "gapic\\/[0-9]+\\.[\\w.-]+ gax\\/[0-9]+\\.[\\w.-]+ gccl\\/[0-9]+\\.[\\w.-]+-" + + client_component + + " gl-python\\/[0-9]+\\.[\\w.-]+ grpc\\/[0-9]+\\.[\\w.-]+" + ) + if self.is_async: + patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") + else: + patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") + with patch as gapic_mock: + client = self._make_client(project="project-id") + wrapped_call_list = gapic_mock.call_args_list + assert len(wrapped_call_list) > 0 + for call in wrapped_call_list: + client_info = call.kwargs["client_info"] + assert client_info is not None, f"{call} has no client_info" + wrapped_user_agent_sorted = " ".join( + sorted(client_info.to_user_agent().split(" ")) + ) + assert VENEER_HEADER_REGEX.match( + wrapped_user_agent_sorted + ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" + client.close() + + def test_channel_pool_creation(self): + pool_size = 14 + with mock.patch.object( + grpc_helpers, "create_channel", mock.Mock() + ) as create_channel: + client = self._make_client(project="project-id", pool_size=pool_size) + assert create_channel.call_count == pool_size + client.close() + client = self._make_client(project="project-id", pool_size=pool_size) + pool_list = list(client.transport._grpc_channel._pool) + pool_set = set(client.transport._grpc_channel._pool) + assert len(pool_list) == len(pool_set) + client.close() + + def test_channel_pool_rotation(self): + pool_size = 7 + with mock.patch.object(PooledChannel, "next_channel") as next_channel: + client = self._make_client(project="project-id", pool_size=pool_size) + assert len(client.transport._grpc_channel._pool) == pool_size + next_channel.reset_mock() + with mock.patch.object( + type(client.transport._grpc_channel._pool[0]), "unary_unary" + ) as unary_unary: + channel_next = None + for i in range(pool_size): + channel_last = channel_next + channel_next = client.transport.grpc_channel._pool[i] + assert channel_last != channel_next + next_channel.return_value = channel_next + client.transport.ping_and_warm() + assert next_channel.call_count == i + 1 + unary_unary.assert_called_once() + unary_unary.reset_mock() + client.close() + + def test_channel_pool_replace(self): + import time + + sleep_module = asyncio if self.is_async else time + with 
mock.patch.object(sleep_module, "sleep"):
+            pool_size = 7
+            client = self._make_client(project="project-id", pool_size=pool_size)
+            for replace_idx in range(pool_size):
+                start_pool = [
+                    channel for channel in client.transport._grpc_channel._pool
+                ]
+                grace_period = 9
+                with mock.patch.object(
+                    type(client.transport._grpc_channel._pool[-1]), "close"
+                ) as close:
+                    new_channel = client.transport.create_channel()
+                    client.transport.replace_channel(
+                        replace_idx, grace=grace_period, new_channel=new_channel
+                    )
+                    close.assert_called_once()
+                    if self.is_async:
+                        close.assert_called_once_with(grace=grace_period)
+                assert client.transport._grpc_channel._pool[replace_idx] == new_channel
+                for i in range(pool_size):
+                    if i != replace_idx:
+                        assert client.transport._grpc_channel._pool[i] == start_pool[i]
+                    else:
+                        assert client.transport._grpc_channel._pool[i] != start_pool[i]
+            client.close()
+
+    def test__start_background_channel_refresh_tasks_exist(self):
+        client = self._make_client(project="project-id", use_emulator=False)
+        assert len(client._channel_refresh_tasks) > 0
+        with mock.patch.object(asyncio, "create_task") as create_task:
+            client._start_background_channel_refresh()
+            create_task.assert_not_called()
+        client.close()
+
+    @pytest.mark.parametrize("pool_size", [1, 3, 7])
+    def test__start_background_channel_refresh(self, pool_size):
+        import concurrent.futures
+
+        with mock.patch.object(
+            self._get_target_class(), "_ping_and_warm_instances", mock.Mock()
+        ) as ping_and_warm:
+            client = self._make_client(
+                project="project-id", pool_size=pool_size, use_emulator=False
+            )
+            client._start_background_channel_refresh()
+            assert len(client._channel_refresh_tasks) == pool_size
+            for task in client._channel_refresh_tasks:
+                if self.is_async:
+                    assert isinstance(task, asyncio.Task)
+                else:
+                    assert isinstance(task, concurrent.futures.Future)
+            time.sleep(0.1)
+            assert ping_and_warm.call_count == pool_size
+            for channel in client.transport._grpc_channel._pool:
+                ping_and_warm.assert_any_call(channel)
+            client.close()
+
+    def test__ping_and_warm_instances(self):
+        """test ping and warm with the fan-out call (gather/submit) mocked"""
+        client_mock = mock.Mock()
+        client_mock._execute_ping_and_warms = (
+            lambda *args: self._get_target_class()._execute_ping_and_warms(
+                client_mock, *args
+            )
+        )
+        gather_tuple = (
+            (asyncio, "gather") if self.is_async else (client_mock._executor, "submit")
+        )
+        with mock.patch.object(*gather_tuple, mock.Mock()) as gather:
+            if self.is_async:
+                gather.side_effect = lambda *args, **kwargs: [None for _ in args]
+            else:
+                gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)]
+            channel = mock.Mock()
+            client_mock._active_instances = []
+            result = self._get_target_class()._ping_and_warm_instances(
+                client_mock, channel
+            )
+            assert len(result) == 0
+            if self.is_async:
+                assert gather.call_args.kwargs == {"return_exceptions": True}
+            client_mock._active_instances = [
+                (mock.Mock(), mock.Mock(), mock.Mock())
+            ] * 4
+            gather.reset_mock()
+            channel.reset_mock()
+            result = self._get_target_class()._ping_and_warm_instances(
+                client_mock, channel
+            )
+            assert len(result) == 4
+            if self.is_async:
+                gather.assert_called_once()
+                assert len(gather.call_args.args) == 4
+            else:
+                assert gather.call_count == 4
+            grpc_call_args = channel.unary_unary().call_args_list
+            for idx, (_, kwargs) in enumerate(grpc_call_args):
+                (
+                    expected_instance,
+                    expected_table,
+                    expected_app_profile,
+                ) = client_mock._active_instances[idx]
+                
request = kwargs["request"] + assert request["name"] == expected_instance + assert request["app_profile_id"] == expected_app_profile + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert ( + metadata[0][1] + == f"name={expected_instance}&app_profile_id={expected_app_profile}" + ) + + def test__ping_and_warm_single_instance(self): + """should be able to call ping and warm with single instance""" + client_mock = mock.Mock() + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) + gather_tuple = ( + (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") + ) + with mock.patch.object(*gather_tuple, mock.Mock()) as gather: + gather.side_effect = lambda *args, **kwargs: [mock.Mock() for _ in args] + if self.is_async: + gather.side_effect = lambda *args, **kwargs: [None for _ in args] + else: + gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] + channel = mock.Mock() + client_mock._active_instances = [mock.Mock()] * 100 + test_key = ("test-instance", "test-table", "test-app-profile") + result = self._get_target_class()._ping_and_warm_instances( + client_mock, channel, test_key + ) + assert len(result) == 1 + grpc_call_args = channel.unary_unary().call_args_list + assert len(grpc_call_args) == 1 + kwargs = grpc_call_args[0][1] + request = kwargs["request"] + assert request["name"] == "test-instance" + assert request["app_profile_id"] == "test-app-profile" + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert ( + metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" + ) + + @pytest.mark.parametrize( + "refresh_interval, wait_time, expected_sleep", + [(0, 0, 0), (0, 1, 0), (10, 0, 10), (10, 5, 5), (10, 10, 0), (10, 15, 0)], + ) + def test__manage_channel_first_sleep( + self, refresh_interval, wait_time, expected_sleep + ): + import threading + import time + + with mock.patch.object(time, "monotonic") as monotonic: + monotonic.return_value = 0 + sleep_tuple = ( + (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: + sleep.side_effect = asyncio.CancelledError + try: + client = self._make_client(project="project-id") + client._channel_init_time = -wait_time + client._manage_channel(0, refresh_interval, refresh_interval) + except asyncio.CancelledError: + pass + sleep.assert_called_once() + call_time = sleep.call_args[0][0] + assert ( + abs(call_time - expected_sleep) < 0.1 + ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" + client.close() + + def test__manage_channel_ping_and_warm(self): + """_manage channel should call ping and warm internally""" + import time + import threading + + client_mock = mock.Mock() + client_mock._is_closed.is_set.return_value = False + client_mock._channel_init_time = time.monotonic() + channel_list = [mock.Mock(), mock.Mock()] + client_mock.transport.channels = channel_list + new_channel = mock.Mock() + client_mock.transport.grpc_channel._create_channel.return_value = new_channel + sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + with mock.patch.object(*sleep_tuple): + client_mock.transport.replace_channel.side_effect = asyncio.CancelledError + ping_and_warm = client_mock._ping_and_warm_instances = mock.Mock() + try: + channel_idx = 1 + 
self._get_target_class()._manage_channel(client_mock, channel_idx, 10) + except asyncio.CancelledError: + pass + assert ping_and_warm.call_count == 2 + assert client_mock.transport.replace_channel.call_count == 1 + old_channel = channel_list[channel_idx] + assert old_channel != new_channel + called_with = [call[0][0] for call in ping_and_warm.call_args_list] + assert old_channel in called_with + assert new_channel in called_with + ping_and_warm.reset_mock() + try: + self._get_target_class()._manage_channel(client_mock, 0, 0, 0) + except asyncio.CancelledError: + pass + ping_and_warm.assert_called_once_with(new_channel) + + @pytest.mark.parametrize( + "refresh_interval, num_cycles, expected_sleep", + [(None, 1, 60 * 35), (10, 10, 100), (10, 1, 10)], + ) + def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sleep): + import time + import random + import threading + + channel_idx = 1 + with mock.patch.object(random, "uniform") as uniform: + uniform.side_effect = lambda min_, max_: min_ + with mock.patch.object(time, "time") as time_mock: + time_mock.return_value = 0 + sleep_tuple = ( + (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: + sleep.side_effect = [None for i in range(num_cycles - 1)] + [ + asyncio.CancelledError + ] + client = self._make_client(project="project-id") + with mock.patch.object(client.transport, "replace_channel"): + try: + if refresh_interval is not None: + client._manage_channel( + channel_idx, refresh_interval, refresh_interval + ) + else: + client._manage_channel(channel_idx) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + total_sleep = sum([call[0][0] for call in sleep.call_args_list]) + assert ( + abs(total_sleep - expected_sleep) < 0.1 + ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" + client.close() + + def test__manage_channel_random(self): + import random + import threading + + sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + with mock.patch.object(*sleep_tuple) as sleep: + with mock.patch.object(random, "uniform") as uniform: + uniform.return_value = 0 + try: + uniform.side_effect = asyncio.CancelledError + client = self._make_client(project="project-id", pool_size=1) + except asyncio.CancelledError: + uniform.side_effect = None + uniform.reset_mock() + sleep.reset_mock() + min_val = 200 + max_val = 205 + uniform.side_effect = lambda min_, max_: min_ + sleep.side_effect = [None, None, asyncio.CancelledError] + try: + with mock.patch.object(client.transport, "replace_channel"): + client._manage_channel(0, min_val, max_val) + except asyncio.CancelledError: + pass + assert uniform.call_count == 3 + uniform_args = [call[0] for call in uniform.call_args_list] + for found_min, found_max in uniform_args: + assert found_min == min_val + assert found_max == max_val + + @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) + def test__manage_channel_refresh(self, num_cycles): + import threading + + expected_grace = 9 + expected_refresh = 0.5 + channel_idx = 1 + grpc_lib = grpc.aio if self.is_async else grpc + new_channel = grpc_lib.insecure_channel("localhost:8080") + with mock.patch.object( + PooledBigtableGrpcTransport, "replace_channel" + ) as replace_channel: + sleep_tuple = ( + (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: + sleep.side_effect = [None for i in range(num_cycles)] 
+ [ + asyncio.CancelledError + ] + with mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + create_channel.return_value = new_channel + with mock.patch.object( + self._get_target_class(), "_start_background_channel_refresh" + ): + client = self._make_client( + project="project-id", use_emulator=False + ) + create_channel.reset_mock() + try: + client._manage_channel( + channel_idx, + refresh_interval_min=expected_refresh, + refresh_interval_max=expected_refresh, + grace_period=expected_grace, + ) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + 1 + assert create_channel.call_count == num_cycles + assert replace_channel.call_count == num_cycles + for call in replace_channel.call_args_list: + (args, kwargs) = call + assert args[0] == channel_idx + assert kwargs["grace"] == expected_grace + assert kwargs["new_channel"] == new_channel + client.close() + + def test__register_instance(self): + """test instance registration""" + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: f"prefix/{b}" + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_tasks = [] + client_mock._start_background_channel_refresh.side_effect = ( + lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + ) + mock_channels = [mock.Mock() for i in range(5)] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = mock.Mock() + table_mock = mock.Mock() + self._get_target_class()._register_instance( + client_mock, "instance-1", table_mock + ) + assert client_mock._start_background_channel_refresh.call_count == 1 + expected_key = ( + "prefix/instance-1", + table_mock.table_name, + table_mock.app_profile_id, + ) + assert len(active_instances) == 1 + assert expected_key == tuple(list(active_instances)[0]) + assert len(instance_owners) == 1 + assert expected_key == tuple(list(instance_owners)[0]) + assert client_mock._channel_refresh_tasks + table_mock2 = mock.Mock() + self._get_target_class()._register_instance( + client_mock, "instance-2", table_mock2 + ) + assert client_mock._start_background_channel_refresh.call_count == 1 + assert client_mock._ping_and_warm_instances.call_count == len(mock_channels) + for channel in mock_channels: + assert channel in [ + call[0][0] + for call in client_mock._ping_and_warm_instances.call_args_list + ] + assert len(active_instances) == 2 + assert len(instance_owners) == 2 + expected_key2 = ( + "prefix/instance-2", + table_mock2.table_name, + table_mock2.app_profile_id, + ) + assert any( + [ + expected_key2 == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) + assert any( + [ + expected_key2 == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) + + @pytest.mark.parametrize( + "insert_instances,expected_active,expected_owner_keys", + [ + ([("i", "t", None)], [("i", "t", None)], [("i", "t", None)]), + ([("i", "t", "p")], [("i", "t", "p")], [("i", "t", "p")]), + ([("1", "t", "p"), ("1", "t", "p")], [("1", "t", "p")], [("1", "t", "p")]), + ( + [("1", "t", "p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + ), + ], + ) + def test__register_instance_state( + self, insert_instances, expected_active, expected_owner_keys + ): + """test that active_instances and instance_owners are updated as expected""" + client_mock = mock.Mock() + 
client_mock._gapic_client.instance_path.side_effect = lambda a, b: b + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_tasks = [] + client_mock._start_background_channel_refresh.side_effect = ( + lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + ) + mock_channels = [mock.Mock() for i in range(5)] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = mock.Mock() + table_mock = mock.Mock() + for instance, table, profile in insert_instances: + table_mock.table_name = table + table_mock.app_profile_id = profile + self._get_target_class()._register_instance( + client_mock, instance, table_mock + ) + assert len(active_instances) == len(expected_active) + assert len(instance_owners) == len(expected_owner_keys) + for expected in expected_active: + assert any( + [ + expected == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) + for expected in expected_owner_keys: + assert any( + [ + expected == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) + + def test__remove_instance_registration(self): + client = self._make_client(project="project-id") + table = mock.Mock() + client._register_instance("instance-1", table) + client._register_instance("instance-2", table) + assert len(client._active_instances) == 2 + assert len(client._instance_owners.keys()) == 2 + instance_1_path = client._gapic_client.instance_path( + client.project, "instance-1" + ) + instance_1_key = (instance_1_path, table.table_name, table.app_profile_id) + instance_2_path = client._gapic_client.instance_path( + client.project, "instance-2" + ) + instance_2_key = (instance_2_path, table.table_name, table.app_profile_id) + assert len(client._instance_owners[instance_1_key]) == 1 + assert list(client._instance_owners[instance_1_key])[0] == id(table) + assert len(client._instance_owners[instance_2_key]) == 1 + assert list(client._instance_owners[instance_2_key])[0] == id(table) + success = client._remove_instance_registration("instance-1", table) + assert success + assert len(client._active_instances) == 1 + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 1 + assert client._active_instances == {instance_2_key} + success = client._remove_instance_registration("fake-key", table) + assert not success + assert len(client._active_instances) == 1 + client.close() + + def test__multiple_table_registration(self): + """ + registering with multiple tables with the same key should + add multiple owners to instance_owners, but only keep one copy + of shared key in active_instances + """ + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + with self._make_client(project="project-id") as client: + with client.get_table("instance_1", "table_1") as table_1: + instance_1_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 1 + assert len(client._active_instances) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + with client.get_table("instance_1", "table_1") as table_2: + assert len(client._instance_owners[instance_1_key]) == 2 + assert len(client._active_instances) == 1 + assert id(table_1) in 
client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + with client.get_table("instance_1", "table_3") as table_3: + instance_3_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_3_key = _WarmedInstanceKey( + instance_3_path, table_3.table_name, table_3.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 2 + assert len(client._instance_owners[instance_3_key]) == 1 + assert len(client._active_instances) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + assert id(table_3) in client._instance_owners[instance_3_key] + assert len(client._active_instances) == 1 + assert instance_1_key in client._active_instances + assert id(table_2) not in client._instance_owners[instance_1_key] + assert len(client._active_instances) == 0 + assert instance_1_key not in client._active_instances + assert len(client._instance_owners[instance_1_key]) == 0 + + def test__multiple_instance_registration(self): + """ + registering with multiple instance keys should update the key + in instance_owners and active_instances + """ + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + with self._make_client(project="project-id") as client: + with client.get_table("instance_1", "table_1") as table_1: + with client.get_table("instance_2", "table_2") as table_2: + instance_1_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id + ) + instance_2_path = client._gapic_client.instance_path( + client.project, "instance_2" + ) + instance_2_key = _WarmedInstanceKey( + instance_2_path, table_2.table_name, table_2.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 1 + assert len(client._instance_owners[instance_2_key]) == 1 + assert len(client._active_instances) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_2_key] + assert len(client._active_instances) == 1 + assert instance_1_key in client._active_instances + assert len(client._instance_owners[instance_2_key]) == 0 + assert len(client._instance_owners[instance_1_key]) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + assert len(client._active_instances) == 0 + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 0 + + def test_get_table(self): + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + client = self._make_client(project="project-id") + assert not client._active_instances + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + table = client.get_table( + expected_instance_id, expected_table_id, expected_app_profile_id + ) + time.sleep(0) + assert isinstance(table, TestTable._get_target_class()) + assert table.table_id == expected_table_id + assert ( + table.table_name + == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert table.instance_id == expected_instance_id + assert ( + table.instance_name + == f"projects/{client.project}/instances/{expected_instance_id}" + ) + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, 
table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + client.close() + + def test_get_table_arg_passthrough(self): + """All arguments passed in get_table should be sent to constructor""" + with self._make_client(project="project-id") as client: + with mock.patch.object( + TestTable._get_target_class(), "__init__" + ) as mock_constructor: + mock_constructor.return_value = None + assert not client._active_instances + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_args = (1, "test", {"test": 2}) + expected_kwargs = {"hello": "world", "test": 2} + client.get_table( + expected_instance_id, + expected_table_id, + expected_app_profile_id, + *expected_args, + **expected_kwargs, + ) + mock_constructor.assert_called_once_with( + client, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + *expected_args, + **expected_kwargs, + ) + + def test_get_table_context_manager(self): + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_project_id = "project-id" + with mock.patch.object(TestTable._get_target_class(), "close") as close_mock: + with self._make_client(project=expected_project_id) as client: + with client.get_table( + expected_instance_id, expected_table_id, expected_app_profile_id + ) as table: + time.sleep(0) + assert isinstance(table, TestTable._get_target_class()) + assert table.table_id == expected_table_id + assert ( + table.table_name + == f"projects/{expected_project_id}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert table.instance_id == expected_instance_id + assert ( + table.instance_name + == f"projects/{expected_project_id}/instances/{expected_instance_id}" + ) + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + assert close_mock.call_count == 1 + + def test_multiple_pool_sizes(self): + pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] + for pool_size in pool_sizes: + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client._channel_refresh_tasks) == pool_size + client_duplicate = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client_duplicate._channel_refresh_tasks) == pool_size + assert str(pool_size) in str(client.transport) + client.close() + client_duplicate.close() + + def test_close(self): + pool_size = 7 + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client._channel_refresh_tasks) == pool_size + tasks_list = list(client._channel_refresh_tasks) + for task in client._channel_refresh_tasks: + assert not task.done() + with mock.patch.object( + PooledBigtableGrpcTransport, "close", mock.Mock() + ) as close_mock: + client.close() + close_mock.assert_called_once() + close_mock.assert_called_once() + for task in tasks_list: + assert task.done() + assert client._channel_refresh_tasks == [] + + def test_context_manager(self): + close_mock = mock.Mock() + true_close = None + with 
self._make_client(project="project-id") as client:
+            true_close = client.close()
+            client.close = close_mock
+            for task in client._channel_refresh_tasks:
+                assert not task.done()
+            assert client.project == "project-id"
+            assert client._active_instances == set()
+            close_mock.assert_not_called()
+        close_mock.assert_called_once()
+        # in the sync surface, client.close() has already run inside the with
+        # block, so there is nothing left to wait on via true_close here
+
+
+class TestTable(ABC):
+    def _make_client(self, *args, **kwargs):
+        return TestBigtableDataClient._make_client(*args, **kwargs)
+
+    @staticmethod
+    def _get_target_class():
+        from google.cloud.bigtable.data._sync.client import Table
+
+        return Table
+
+    @property
+    def is_async(self):
+        return False
+
+    def test_table_ctor(self):
+        from google.cloud.bigtable.data._helpers import _WarmedInstanceKey
+
+        expected_table_id = "table-id"
+        expected_instance_id = "instance-id"
+        expected_app_profile_id = "app-profile-id"
+        expected_operation_timeout = 123
+        expected_attempt_timeout = 12
+        expected_read_rows_operation_timeout = 1.5
+        expected_read_rows_attempt_timeout = 0.5
+        expected_mutate_rows_operation_timeout = 2.5
+        expected_mutate_rows_attempt_timeout = 0.75
+        client = self._make_client()
+        assert not client._active_instances
+        table = self._get_target_class()(
+            client,
+            expected_instance_id,
+            expected_table_id,
+            expected_app_profile_id,
+            default_operation_timeout=expected_operation_timeout,
+            default_attempt_timeout=expected_attempt_timeout,
+            default_read_rows_operation_timeout=expected_read_rows_operation_timeout,
+            default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout,
+            default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout,
+            default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout,
+        )
+        time.sleep(0)
+        assert table.table_id == expected_table_id
+        assert table.instance_id == expected_instance_id
+        assert table.app_profile_id == expected_app_profile_id
+        assert table.client is client
+        instance_key = _WarmedInstanceKey(
+            table.instance_name, table.table_name, table.app_profile_id
+        )
+        assert instance_key in client._active_instances
+        assert client._instance_owners[instance_key] == {id(table)}
+        assert table.default_operation_timeout == expected_operation_timeout
+        assert table.default_attempt_timeout == expected_attempt_timeout
+        assert (
+            table.default_read_rows_operation_timeout
+            == expected_read_rows_operation_timeout
+        )
+        assert (
+            table.default_read_rows_attempt_timeout
+            == expected_read_rows_attempt_timeout
+        )
+        assert (
+            table.default_mutate_rows_operation_timeout
+            == expected_mutate_rows_operation_timeout
+        )
+        assert (
+            table.default_mutate_rows_attempt_timeout
+            == expected_mutate_rows_attempt_timeout
+        )
+        # wait for the registration future to complete before inspecting it
+        table._register_instance_future.result()
+        assert table._register_instance_future.done()
+        assert not table._register_instance_future.cancelled()
+        assert table._register_instance_future.exception() is None
+        client.close()
+
+    def test_table_ctor_defaults(self):
+        """should provide default timeout values and app_profile_id"""
+        expected_table_id = "table-id"
+        expected_instance_id = "instance-id"
+        client = self._make_client()
+        assert not client._active_instances
+        table = Table(client, expected_instance_id, expected_table_id)
+        time.sleep(0)
+        assert table.table_id == expected_table_id
+        assert table.instance_id == expected_instance_id
+        assert table.app_profile_id is None
+        assert table.client is client
+        assert table.default_operation_timeout == 60
+        assert table.default_read_rows_operation_timeout == 600
+        assert 
table.default_mutate_rows_operation_timeout == 600 + assert table.default_attempt_timeout == 20 + assert table.default_read_rows_attempt_timeout == 20 + assert table.default_mutate_rows_attempt_timeout == 60 + client.close() + + def test_table_ctor_invalid_timeout_values(self): + """bad timeout values should raise ValueError""" + client = self._make_client() + timeout_pairs = [ + ("default_operation_timeout", "default_attempt_timeout"), + ( + "default_read_rows_operation_timeout", + "default_read_rows_attempt_timeout", + ), + ( + "default_mutate_rows_operation_timeout", + "default_mutate_rows_attempt_timeout", + ), + ] + for operation_timeout, attempt_timeout in timeout_pairs: + with pytest.raises(ValueError) as e: + Table(client, "", "", **{attempt_timeout: -1}) + assert "attempt_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + Table(client, "", "", **{operation_timeout: -1}) + assert "operation_timeout must be greater than 0" in str(e.value) + client.close() + + @pytest.mark.parametrize( + "fn_name,fn_args,is_stream,extra_retryables", + [ + ("read_rows_stream", (ReadRowsQuery(),), True, ()), + ("read_rows", (ReadRowsQuery(),), True, ()), + ("read_row", (b"row_key",), True, ()), + ("read_rows_sharded", ([ReadRowsQuery()],), True, ()), + ("row_exists", (b"row_key",), True, ()), + ("sample_row_keys", (), False, ()), + ("mutate_row", (b"row_key", [mock.Mock()]), False, ()), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + False, + (_MutateRowsIncomplete,), + ), + ], + ) + @pytest.mark.parametrize( + "input_retryables,expected_retryables", + [ + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), + ], + ) + def test_customizable_retryable_errors( + self, + input_retryables, + expected_retryables, + fn_name, + fn_args, + is_stream, + extra_retryables, + ): + """ + Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer. 
+ """ + retry_fn = "retry_target" + if is_stream: + retry_fn += "_stream" + if self.is_async: + retry_fn += "_async" + with mock.patch(f"google.api_core.retry.{retry_fn}") as retry_fn_mock: + with self._make_client() as client: + table = client.get_table("instance-id", "table-id") + expected_predicate = lambda a: a in expected_retryables + retry_fn_mock.side_effect = RuntimeError("stop early") + with mock.patch( + "google.api_core.retry.if_exception_type" + ) as predicate_builder_mock: + predicate_builder_mock.return_value = expected_predicate + with pytest.raises(Exception): + test_fn = table.__getattribute__(fn_name) + test_fn(*fn_args, retryable_errors=input_retryables) + predicate_builder_mock.assert_called_once_with( + *expected_retryables, *extra_retryables + ) + retry_call_args = retry_fn_mock.call_args_list[0].args + assert retry_call_args[1] is expected_predicate + + @pytest.mark.parametrize( + "fn_name,fn_args,gapic_fn", + [ + ("read_rows_stream", (ReadRowsQuery(),), "read_rows"), + ("read_rows", (ReadRowsQuery(),), "read_rows"), + ("read_row", (b"row_key",), "read_rows"), + ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), + ("row_exists", (b"row_key",), "read_rows"), + ("sample_row_keys", (), "sample_row_keys"), + ("mutate_row", (b"row_key", [mock.Mock()]), "mutate_row"), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + "mutate_rows", + ), + ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), + ( + "read_modify_write_row", + (b"row_key", mock.Mock()), + "read_modify_write_row", + ), + ], + ) + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): + """check that all requests attach proper metadata headers""" + profile = "profile" if include_app_profile else None + with mock.patch.object( + BigtableClient, gapic_fn, mock.mock.Mock() + ) as gapic_mock: + gapic_mock.side_effect = RuntimeError("stop early") + with self._make_client() as client: + table = Table(client, "instance-id", "table-id", profile) + try: + test_fn = table.__getattribute__(fn_name) + maybe_stream = test_fn(*fn_args) + [i for i in maybe_stream] + except Exception: + pass + kwargs = gapic_mock.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == "x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata + + +class TestReadRows(ABC): + """ + Tests for table.read_rows and related methods. 
+ """ + + @staticmethod + def _get_operation_class(): + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + + return _ReadRowsOperation + + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + + def _make_table(self, *args, **kwargs): + client_mock = mock.Mock() + client_mock._register_instance.side_effect = lambda *args, **kwargs: time.sleep( + 0 + ) + client_mock._remove_instance_registration.side_effect = ( + lambda *args, **kwargs: time.sleep(0) ) - assert client_mock._start_background_channel_refresh.call_count == 1 - assert client_mock._ping_and_warm_instances.call_count == len(mock_channels) - for channel in mock_channels: - assert channel in [ - call[0][0] - for call in client_mock._ping_and_warm_instances.call_args_list - ] - assert len(active_instances) == 2 - assert len(instance_owners) == 2 - expected_key2 = ( - "prefix/instance-2", - table_mock2.table_name, - table_mock2.app_profile_id, + kwargs["instance_id"] = kwargs.get( + "instance_id", args[0] if args else "instance" ) - assert any( - [ - expected_key2 == tuple(list(active_instances)[i]) - for i in range(len(active_instances)) - ] + kwargs["table_id"] = kwargs.get( + "table_id", args[1] if len(args) > 1 else "table" ) - assert any( - [ - expected_key2 == tuple(list(instance_owners)[i]) - for i in range(len(instance_owners)) - ] + client_mock._gapic_client.table_path.return_value = kwargs["table_id"] + client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] + return Table(client_mock, *args, **kwargs) + + def _make_stats(self): + from google.cloud.bigtable_v2.types import RequestStats + from google.cloud.bigtable_v2.types import FullReadStatsView + from google.cloud.bigtable_v2.types import ReadIterationStats + + return RequestStats( + full_read_stats_view=FullReadStatsView( + read_iteration_stats=ReadIterationStats( + rows_seen_count=1, + rows_returned_count=2, + cells_seen_count=3, + cells_returned_count=4, + ) + ) ) - @pytest.mark.parametrize( - "insert_instances,expected_active,expected_owner_keys", - [ - ([("i", "t", None)], [("i", "t", None)], [("i", "t", None)]), - ([("i", "t", "p")], [("i", "t", "p")], [("i", "t", "p")]), - ([("1", "t", "p"), ("1", "t", "p")], [("1", "t", "p")], [("1", "t", "p")]), - ( - [("1", "t", "p"), ("2", "t", "p")], - [("1", "t", "p"), ("2", "t", "p")], - [("1", "t", "p"), ("2", "t", "p")], - ), - ], - ) - def test__register_instance_state( - self, insert_instances, expected_active, expected_owner_keys + @staticmethod + def _make_chunk(*args, **kwargs): + from google.cloud.bigtable_v2 import ReadRowsResponse + + kwargs["row_key"] = kwargs.get("row_key", b"row_key") + kwargs["family_name"] = kwargs.get("family_name", "family_name") + kwargs["qualifier"] = kwargs.get("qualifier", b"qualifier") + kwargs["value"] = kwargs.get("value", b"value") + kwargs["commit_row"] = kwargs.get("commit_row", True) + return ReadRowsResponse.CellChunk(*args, **kwargs) + + @staticmethod + def _make_gapic_stream( + chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0 ): - """test that active_instances and instance_owners are updated as expected""" - client_mock = mock.Mock() - client_mock._gapic_client.instance_path.side_effect = lambda a, b: b - active_instances = set() - instance_owners = {} - client_mock._active_instances = active_instances - client_mock._instance_owners = instance_owners - client_mock._channel_refresh_tasks = [] - client_mock._start_background_channel_refresh.side_effect = ( - 
lambda: client_mock._channel_refresh_tasks.append(mock.Mock) - ) - mock_channels = [mock.Mock() for i in range(5)] - client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = mock.Mock() - table_mock = mock.Mock() - for instance, table, profile in insert_instances: - table_mock.table_name = table - table_mock.app_profile_id = profile - self._get_target_class()._register_instance( - client_mock, instance, table_mock - ) - assert len(active_instances) == len(expected_active) - assert len(instance_owners) == len(expected_owner_keys) - for expected in expected_active: - assert any( - [ - expected == tuple(list(active_instances)[i]) - for i in range(len(active_instances)) - ] + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list, sleep_time): + self.chunk_list = chunk_list + self.idx = -1 + self.sleep_time = sleep_time + + def __iter__(self): + return self + + def __next__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + if sleep_time: + time.sleep(self.sleep_time) + chunk = self.chunk_list[self.idx] + if isinstance(chunk, Exception): + raise chunk + else: + return ReadRowsResponse(chunks=[chunk]) + raise StopIteration + + def cancel(self): + pass + + return mock_stream(chunk_list, sleep_time) + + def execute_fn(self, table, *args, **kwargs): + return table.read_rows(*args, **kwargs) + + def test_read_rows(self): + query = ReadRowsQuery() + chunks = [ + self._make_chunk(row_key=b"test_1"), + self._make_chunk(row_key=b"test_2"), + ] + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks ) - for expected in expected_owner_keys: - assert any( - [ - expected == tuple(list(instance_owners)[i]) - for i in range(len(instance_owners)) - ] + results = self.execute_fn(table, query, operation_timeout=3) + assert len(results) == 2 + assert results[0].row_key == b"test_1" + assert results[1].row_key == b"test_2" + + def test_read_rows_stream(self): + query = ReadRowsQuery() + chunks = [ + self._make_chunk(row_key=b"test_1"), + self._make_chunk(row_key=b"test_2"), + ] + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks ) + gen = table.read_rows_stream(query, operation_timeout=3) + results = [row for row in gen] + assert len(results) == 2 + assert results[0].row_key == b"test_1" + assert results[1].row_key == b"test_2" - def test__remove_instance_registration(self): - client = self._make_client(project="project-id") - table = mock.Mock() - client._register_instance("instance-1", table) - client._register_instance("instance-2", table) - assert len(client._active_instances) == 2 - assert len(client._instance_owners.keys()) == 2 - instance_1_path = client._gapic_client.instance_path( - client.project, "instance-1" - ) - instance_1_key = (instance_1_path, table.table_name, table.app_profile_id) - instance_2_path = client._gapic_client.instance_path( - client.project, "instance-2" - ) - instance_2_key = (instance_2_path, table.table_name, table.app_profile_id) - assert len(client._instance_owners[instance_1_key]) == 1 - assert list(client._instance_owners[instance_1_key])[0] == id(table) - assert len(client._instance_owners[instance_2_key]) == 1 - assert list(client._instance_owners[instance_2_key])[0] == id(table) - success = client._remove_instance_registration("instance-1", table) 
- assert success - assert len(client._active_instances) == 1 - assert len(client._instance_owners[instance_1_key]) == 0 - assert len(client._instance_owners[instance_2_key]) == 1 - assert client._active_instances == {instance_2_key} - success = client._remove_instance_registration("fake-key", table) - assert not success - assert len(client._active_instances) == 1 - client.close() + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_read_rows_query_matches_request(self, include_app_profile): + from google.cloud.bigtable.data import RowRange + from google.cloud.bigtable.data.row_filters import PassAllFilter - def test__multiple_table_registration(self): - """ - registering with multiple tables with the same key should - add multiple owners to instance_owners, but only keep one copy - of shared key in active_instances - """ - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + app_profile_id = "app_profile_id" if include_app_profile else None + with self._make_table(app_profile_id=app_profile_id) as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream([]) + row_keys = [b"test_1", "test_2"] + row_ranges = RowRange("1start", "2end") + filter_ = PassAllFilter(True) + limit = 99 + query = ReadRowsQuery( + row_keys=row_keys, + row_ranges=row_ranges, + row_filter=filter_, + limit=limit, + ) + results = table.read_rows(query, operation_timeout=3) + assert len(results) == 0 + call_request = read_rows.call_args_list[0][0][0] + query_pb = query._to_pb(table) + assert call_request == query_pb - with self._make_client(project="project-id") as client: - with client.get_table("instance_1", "table_1") as table_1: - instance_1_path = client._gapic_client.instance_path( - client.project, "instance_1" - ) - instance_1_key = _WarmedInstanceKey( - instance_1_path, table_1.table_name, table_1.app_profile_id + @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) + def test_read_rows_timeout(self, operation_timeout): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + query = ReadRowsQuery() + chunks = [self._make_chunk(row_key=b"test_1")] + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=1 + ) + try: + table.read_rows(query, operation_timeout=operation_timeout) + except core_exceptions.DeadlineExceeded as e: + assert ( + e.message + == f"operation_timeout of {operation_timeout:0.1f}s exceeded" ) - assert len(client._instance_owners[instance_1_key]) == 1 - assert len(client._active_instances) == 1 - assert id(table_1) in client._instance_owners[instance_1_key] - with client.get_table("instance_1", "table_1") as table_2: - assert len(client._instance_owners[instance_1_key]) == 2 - assert len(client._active_instances) == 1 - assert id(table_1) in client._instance_owners[instance_1_key] - assert id(table_2) in client._instance_owners[instance_1_key] - with client.get_table("instance_1", "table_3") as table_3: - instance_3_path = client._gapic_client.instance_path( - client.project, "instance_1" - ) - instance_3_key = _WarmedInstanceKey( - instance_3_path, table_3.table_name, table_3.app_profile_id - ) - assert len(client._instance_owners[instance_1_key]) == 2 - assert len(client._instance_owners[instance_3_key]) == 1 - assert len(client._active_instances) == 2 - assert id(table_1) in client._instance_owners[instance_1_key] - assert id(table_2) in client._instance_owners[instance_1_key] - assert 
id(table_3) in client._instance_owners[instance_3_key] - assert len(client._active_instances) == 1 - assert instance_1_key in client._active_instances - assert id(table_2) not in client._instance_owners[instance_1_key] - assert len(client._active_instances) == 0 - assert instance_1_key not in client._active_instances - assert len(client._instance_owners[instance_1_key]) == 0 - def test__multiple_instance_registration(self): + @pytest.mark.parametrize( + "per_request_t, operation_t, expected_num", + [(0.05, 0.08, 2), (0.05, 0.54, 11), (0.05, 0.14, 3), (0.05, 0.24, 5)], + ) + def test_read_rows_attempt_timeout(self, per_request_t, operation_t, expected_num): """ - registering with multiple instance keys should update the key - in instance_owners and active_instances + Ensures that the attempt_timeout is respected and that the number of + requests is as expected. + + operation_timeout does not cancel the request, so we expect the number of + requests to be the ceiling of operation_timeout / attempt_timeout. """ - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - with self._make_client(project="project-id") as client: - with client.get_table("instance_1", "table_1") as table_1: - with client.get_table("instance_2", "table_2") as table_2: - instance_1_path = client._gapic_client.instance_path( - client.project, "instance_1" - ) - instance_1_key = _WarmedInstanceKey( - instance_1_path, table_1.table_name, table_1.app_profile_id - ) - instance_2_path = client._gapic_client.instance_path( - client.project, "instance_2" + expected_last_timeout = operation_t - (expected_num - 1) * per_request_t + with mock.patch("random.uniform", side_effect=lambda a, b: 0): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=per_request_t + ) + query = ReadRowsQuery() + chunks = [core_exceptions.DeadlineExceeded("mock deadline")] + try: + table.read_rows( + query, + operation_timeout=operation_t, + attempt_timeout=per_request_t, ) - instance_2_key = _WarmedInstanceKey( - instance_2_path, table_2.table_name, table_2.app_profile_id + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + if expected_num == 0: + assert retry_exc is None + else: + assert type(retry_exc) is RetryExceptionGroup + assert f"{expected_num} failed attempts" in str(retry_exc) + assert len(retry_exc.exceptions) == expected_num + for sub_exc in retry_exc.exceptions: + assert sub_exc.message == "mock deadline" + assert read_rows.call_count == expected_num + for _, call_kwargs in read_rows.call_args_list[:-1]: + assert call_kwargs["timeout"] == per_request_t + assert call_kwargs["retry"] is None + assert ( + abs( + read_rows.call_args_list[-1][1]["timeout"] + - expected_last_timeout ) - assert len(client._instance_owners[instance_1_key]) == 1 - assert len(client._instance_owners[instance_2_key]) == 1 - assert len(client._active_instances) == 2 - assert id(table_1) in client._instance_owners[instance_1_key] - assert id(table_2) in client._instance_owners[instance_2_key] - assert len(client._active_instances) == 1 - assert instance_1_key in client._active_instances - assert len(client._instance_owners[instance_2_key]) == 0 - assert len(client._instance_owners[instance_1_key]) == 1 - assert id(table_1) in client._instance_owners[instance_1_key] - assert len(client._active_instances) == 0 - assert 
len(client._instance_owners[instance_1_key]) == 0 - assert len(client._instance_owners[instance_2_key]) == 0 + < 0.05 + ) - def test_get_table(self): - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + @pytest.mark.parametrize( + "exc_type", + [ + core_exceptions.Aborted, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ], + ) + def test_read_rows_retryable_error(self, exc_type): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + try: + table.read_rows(query, operation_timeout=0.1) + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + root_cause = retry_exc.exceptions[0] + assert type(root_cause) is exc_type + assert root_cause == expected_error - client = self._make_client(project="project-id") - assert not client._active_instances - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - table = client.get_table( - expected_instance_id, expected_table_id, expected_app_profile_id - ) - time.sleep(0) - assert isinstance(table, Table) - assert table.table_id == expected_table_id - assert ( - table.table_name - == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" - ) - assert table.instance_id == expected_instance_id - assert ( - table.instance_name - == f"projects/{client.project}/instances/{expected_instance_id}" - ) - assert table.app_profile_id == expected_app_profile_id - assert table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id - ) - assert instance_key in client._active_instances - assert client._instance_owners[instance_key] == {id(table)} - client.close() + @pytest.mark.parametrize( + "exc_type", + [ + core_exceptions.Cancelled, + core_exceptions.PreconditionFailed, + core_exceptions.NotFound, + core_exceptions.PermissionDenied, + core_exceptions.Conflict, + core_exceptions.InternalServerError, + core_exceptions.TooManyRequests, + core_exceptions.ResourceExhausted, + InvalidChunk, + ], + ) + def test_read_rows_non_retryable_error(self, exc_type): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + try: + table.read_rows(query, operation_timeout=0.1) + except exc_type as e: + assert e == expected_error + + def test_read_rows_revise_request(self): + """Ensure that _revise_request is called between retries""" + from google.cloud.bigtable.data.exceptions import InvalidChunk + from google.cloud.bigtable_v2.types import RowSet + + return_val = RowSet() + with mock.patch.object( + self._get_operation_class(), "_revise_request_rowset" + ) as revise_rowset: + revise_rowset.return_value = return_val + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + row_keys = [b"test_1", b"test_2", b"test_3"] + query = ReadRowsQuery(row_keys=row_keys) + chunks = [ + self._make_chunk(row_key=b"test_1"), + core_exceptions.Aborted("mock retryable error"), + ] + try: + table.read_rows(query) + except InvalidChunk: + revise_rowset.assert_called() + 
first_call_kwargs = revise_rowset.call_args_list[0].kwargs + assert first_call_kwargs["row_set"] == query._to_pb(table).rows + assert first_call_kwargs["last_seen_row_key"] == b"test_1" + revised_call = read_rows.call_args_list[1].args[0] + assert revised_call.rows == return_val + + def test_read_rows_default_timeouts(self): + """Ensure that the default timeouts are set on the read rows operation when not overridden""" + operation_timeout = 8 + attempt_timeout = 4 + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: + mock_op.side_effect = RuntimeError("mock error") + with self._make_table( + default_read_rows_operation_timeout=operation_timeout, + default_read_rows_attempt_timeout=attempt_timeout, + ) as table: + try: + table.read_rows(ReadRowsQuery()) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["attempt_timeout"] == attempt_timeout + + def test_read_rows_default_timeout_override(self): + """When timeouts are passed, they overwrite default values""" + operation_timeout = 8 + attempt_timeout = 4 + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: + mock_op.side_effect = RuntimeError("mock error") + with self._make_table( + default_operation_timeout=99, default_attempt_timeout=97 + ) as table: + try: + table.read_rows( + ReadRowsQuery(), + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + ) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["attempt_timeout"] == attempt_timeout - def test_get_table_arg_passthrough(self): - """All arguments passed in get_table should be sent to constructor""" - with self._make_client(project="project-id") as client: - with mock.patch.object( - TestTable._get_target_class(), "__init__" - ) as mock_constructor: - mock_constructor.return_value = None - assert not client._active_instances - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - expected_args = (1, "test", {"test": 2}) - expected_kwargs = {"hello": "world", "test": 2} - client.get_table( - expected_instance_id, - expected_table_id, - expected_app_profile_id, - *expected_args, - **expected_kwargs, + def test_read_row(self): + """Test reading a single row""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + expected_result = object() + read_rows.side_effect = lambda *args, **kwargs: [expected_result] + expected_op_timeout = 8 + expected_req_timeout = 4 + row = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, ) - mock_constructor.assert_called_once_with( - client, - expected_instance_id, - expected_table_id, - expected_app_profile_id, - *expected_args, - **expected_kwargs, + assert row == expected_result + assert read_rows.call_count == 1 + (args, kwargs) = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert len(args) == 1 + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + + def test_read_row_w_filter(self): + """Test reading a single row with an added filter""" + with 
self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + expected_result = object() + read_rows.side_effect = lambda *args, **kwargs: [expected_result] + expected_op_timeout = 8 + expected_req_timeout = 4 + mock_filter = mock.Mock() + expected_filter = {"filter": "mock filter"} + mock_filter._to_dict.return_value = expected_filter + row = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + row_filter=expected_filter, ) + assert row == expected_result + assert read_rows.call_count == 1 + (args, kwargs) = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert len(args) == 1 + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + assert query.filter == expected_filter - def test_get_table_context_manager(self): - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + def test_read_row_no_response(self): + """should return None if row does not exist""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = lambda *args, **kwargs: [] + expected_op_timeout = 8 + expected_req_timeout = 4 + result = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert result is None + assert read_rows.call_count == 1 + (args, kwargs) = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - expected_project_id = "project-id" - with mock.patch.object(Table, "close") as close_mock: - with self._make_client(project=expected_project_id) as client: - with client.get_table( - expected_instance_id, expected_table_id, expected_app_profile_id - ) as table: - time.sleep(0) - assert isinstance(table, Table) - assert table.table_id == expected_table_id - assert ( - table.table_name - == f"projects/{expected_project_id}/instances/{expected_instance_id}/tables/{expected_table_id}" - ) - assert table.instance_id == expected_instance_id - assert ( - table.instance_name - == f"projects/{expected_project_id}/instances/{expected_instance_id}" - ) - assert table.app_profile_id == expected_app_profile_id - assert table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id + @pytest.mark.parametrize( + "return_value,expected_result", + [([], False), ([object()], True), ([object(), object()], True)], + ) + def test_row_exists(self, return_value, expected_result): + """Test checking for row existence""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = lambda *args, **kwargs: return_value + expected_op_timeout = 1 + expected_req_timeout = 2 + result = table.row_exists( + row_key, + 
operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert expected_result == result + assert read_rows.call_count == 1 + (args, kwargs) = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert isinstance(args[0], ReadRowsQuery) + expected_filter = { + "chain": { + "filters": [ + {"cells_per_row_limit_filter": 1}, + {"strip_value_transformer": True}, + ] + } + } + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + assert query.filter._to_dict() == expected_filter + + +class TestReadRowsSharded(ABC): + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + + def test_read_rows_sharded_empty_query(self): + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as exc: + table.read_rows_sharded([]) + assert "empty sharded_query" in str(exc.value) + + def test_read_rows_sharded_multiple_queries(self): + """Test with multiple queries. Should return results from both""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "read_rows" + ) as read_rows: + read_rows.side_effect = ( + lambda *args, **kwargs: TestReadRows._make_gapic_stream( + [ + TestReadRows._make_chunk(row_key=k) + for k in args[0].rows.row_keys + ] + ) ) - assert instance_key in client._active_instances - assert client._instance_owners[instance_key] == {id(table)} - assert close_mock.call_count == 1 + query_1 = ReadRowsQuery(b"test_1") + query_2 = ReadRowsQuery(b"test_2") + result = table.read_rows_sharded([query_1, query_2]) + assert len(result) == 2 + assert result[0].row_key == b"test_1" + assert result[1].row_key == b"test_2" - def test_multiple_pool_sizes(self): - pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] - for pool_size in pool_sizes: - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - assert len(client._channel_refresh_tasks) == pool_size - client_duplicate = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - assert len(client_duplicate._channel_refresh_tasks) == pool_size - assert str(pool_size) in str(client.transport) - client.close() - client_duplicate.close() + @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24]) + def test_read_rows_sharded_multiple_queries_calls(self, n_queries): + """Each query should trigger a separate read_rows call""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + query_list = [ReadRowsQuery() for _ in range(n_queries)] + table.read_rows_sharded(query_list) + assert read_rows.call_count == n_queries - def test_close(self): - pool_size = 7 - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False + def test_read_rows_sharded_errors(self): + """Errors should be exposed as ShardedReadRowsExceptionGroups""" + from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedQueryShardError + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = RuntimeError("mock 
error") + query_1 = ReadRowsQuery(b"test_1") + query_2 = ReadRowsQuery(b"test_2") + with pytest.raises(ShardedReadRowsExceptionGroup) as exc: + table.read_rows_sharded([query_1, query_2]) + exc_group = exc.value + assert isinstance(exc_group, ShardedReadRowsExceptionGroup) + assert len(exc.value.exceptions) == 2 + assert isinstance(exc.value.exceptions[0], FailedQueryShardError) + assert isinstance(exc.value.exceptions[0].__cause__, RuntimeError) + assert exc.value.exceptions[0].index == 0 + assert exc.value.exceptions[0].query == query_1 + assert isinstance(exc.value.exceptions[1], FailedQueryShardError) + assert isinstance(exc.value.exceptions[1].__cause__, RuntimeError) + assert exc.value.exceptions[1].index == 1 + assert exc.value.exceptions[1].query == query_2 + + def test_read_rows_sharded_concurrent(self): + """Ensure sharded requests are concurrent""" + import time + + def mock_call(*args, **kwargs): + time.sleep(0.1) + return [mock.Mock()] + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = mock_call + queries = [ReadRowsQuery() for _ in range(10)] + start_time = time.monotonic() + result = table.read_rows_sharded(queries) + call_time = time.monotonic() - start_time + assert read_rows.call_count == 10 + assert len(result) == 10 + assert call_time < 0.2 + + def test_read_rows_sharded_batching(self): + """ + Large queries should be processed in batches to limit concurrency + operation timeout should change between batches + """ + from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT + + assert _CONCURRENCY_LIMIT == 10 + n_queries = 90 + expected_num_batches = n_queries // _CONCURRENCY_LIMIT + query_list = [ReadRowsQuery() for _ in range(n_queries)] + start_operation_timeout = 10 + start_attempt_timeout = 3 + client = self._make_client(use_emulator=True) + table = client.get_table( + "instance", + "table", + default_read_rows_operation_timeout=start_operation_timeout, + default_read_rows_attempt_timeout=start_attempt_timeout, ) - assert len(client._channel_refresh_tasks) == pool_size - tasks_list = list(client._channel_refresh_tasks) - for task in client._channel_refresh_tasks: - assert not task.done() - with mock.patch.object( - PooledBigtableGrpcTransport, "close", mock.Mock() - ) as close_mock: - client.close() - close_mock.assert_called_once() - close_mock.assert_called_once() - for task in tasks_list: - assert task.done() - assert client._channel_refresh_tasks == [] - def test_context_manager(self): - close_mock = mock.Mock() - true_close = None - with self._make_client(project="project-id") as client: - true_close = client.close() - client.close = close_mock - for task in client._channel_refresh_tasks: - assert not task.done() - assert client.project == "project-id" - assert client._active_instances == set() - close_mock.assert_not_called() - close_mock.assert_called_once() - close_mock.assert_called_once() - true_close + def mock_time_generator(start_op, _): + for i in range(0, 100000): + yield (start_op - i) + with mock.patch( + f"google.cloud.bigtable.data._helpers._attempt_timeout_generator" + ) as time_gen_mock: + time_gen_mock.side_effect = mock_time_generator + with mock.patch.object(table, "read_rows", mock.Mock()) as read_rows_mock: + read_rows_mock.return_value = [] + table.read_rows_sharded(query_list) + assert read_rows_mock.call_count == n_queries + kwargs = [ + read_rows_mock.call_args_list[idx][1] for idx in range(n_queries) 
+ ] + for batch_idx in range(expected_num_batches): + batch_kwargs = kwargs[ + batch_idx + * _CONCURRENCY_LIMIT : (batch_idx + 1) + * _CONCURRENCY_LIMIT + ] + for req_kwargs in batch_kwargs: + expected_operation_timeout = start_operation_timeout - batch_idx + assert ( + req_kwargs["operation_timeout"] + == expected_operation_timeout + ) + expected_attempt_timeout = min( + start_attempt_timeout, expected_operation_timeout + ) + assert req_kwargs["attempt_timeout"] == expected_attempt_timeout -class TestTable(ABC): + +class TestSampleRowKeys(ABC): def _make_client(self, *args, **kwargs): return TestBigtableDataClient._make_client(*args, **kwargs) - @staticmethod - def _get_target_class(): - from google.cloud.bigtable.data._sync.client import Table + def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): + from google.cloud.bigtable_v2.types import SampleRowKeysResponse - return Table + for value in sample_list: + yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) - @property - def is_async(self): - return False + def test_sample_row_keys(self): + """Test that method returns the expected key samples""" + samples = [(b"test_1", 0), (b"test_2", 100), (b"test_3", 200)] + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", mock.Mock() + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream(samples) + result = table.sample_row_keys() + assert len(result) == 3 + assert all((isinstance(r, tuple) for r in result)) + assert all((isinstance(r[0], bytes) for r in result)) + assert all((isinstance(r[1], int) for r in result)) + assert result[0] == samples[0] + assert result[1] == samples[1] + assert result[2] == samples[2] - def test_table_ctor(self): - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + def test_sample_row_keys_bad_timeout(self): + """should raise error if timeout is negative""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.sample_row_keys(operation_timeout=-1) + assert "operation_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + table.sample_row_keys(attempt_timeout=-1) + assert "attempt_timeout must be greater than 0" in str(e.value) - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - expected_operation_timeout = 123 - expected_attempt_timeout = 12 - expected_read_rows_operation_timeout = 1.5 - expected_read_rows_attempt_timeout = 0.5 - expected_mutate_rows_operation_timeout = 2.5 - expected_mutate_rows_attempt_timeout = 0.75 - client = self._make_client() - assert not client._active_instances - table = self._get_target_class()( - client, - expected_instance_id, - expected_table_id, - expected_app_profile_id, - default_operation_timeout=expected_operation_timeout, - default_attempt_timeout=expected_attempt_timeout, - default_read_rows_operation_timeout=expected_read_rows_operation_timeout, - default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout, - default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, - default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, - ) - time.sleep(0) - assert table.table_id == expected_table_id - assert table.instance_id == expected_instance_id - assert table.app_profile_id == expected_app_profile_id - assert 
table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id - ) - assert instance_key in client._active_instances - assert client._instance_owners[instance_key] == {id(table)} - assert table.default_operation_timeout == expected_operation_timeout - assert table.default_attempt_timeout == expected_attempt_timeout - assert ( - table.default_read_rows_operation_timeout - == expected_read_rows_operation_timeout - ) - assert ( - table.default_read_rows_attempt_timeout - == expected_read_rows_attempt_timeout - ) - assert ( - table.default_mutate_rows_operation_timeout - == expected_mutate_rows_operation_timeout - ) - assert ( - table.default_mutate_rows_attempt_timeout - == expected_mutate_rows_attempt_timeout - ) - table._register_instance_future - assert table._register_instance_future.done() - assert not table._register_instance_future.cancelled() - assert table._register_instance_future.exception() is None - client.close() + def test_sample_row_keys_default_timeout(self): + """Should fallback to using table default operation_timeout""" + expected_timeout = 99 + with self._make_client() as client: + with client.get_table( + "i", + "t", + default_operation_timeout=expected_timeout, + default_attempt_timeout=expected_timeout, + ) as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", mock.Mock() + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream([]) + result = table.sample_row_keys() + (_, kwargs) = sample_row_keys.call_args + assert abs(kwargs["timeout"] - expected_timeout) < 0.1 + assert result == [] + assert kwargs["retry"] is None - def test_table_ctor_defaults(self): - """should provide default timeout values and app_profile_id""" - expected_table_id = "table-id" - expected_instance_id = "instance-id" - client = self._make_client() - assert not client._active_instances - table = Table(client, expected_instance_id, expected_table_id) - time.sleep(0) - assert table.table_id == expected_table_id - assert table.instance_id == expected_instance_id - assert table.app_profile_id is None - assert table.client is client - assert table.default_operation_timeout == 60 - assert table.default_read_rows_operation_timeout == 600 - assert table.default_mutate_rows_operation_timeout == 600 - assert table.default_attempt_timeout == 20 - assert table.default_read_rows_attempt_timeout == 20 - assert table.default_mutate_rows_attempt_timeout == 60 - client.close() + def test_sample_row_keys_gapic_params(self): + """make sure arguments are propagated to gapic call as expected""" + expected_timeout = 10 + expected_profile = "test1" + instance = "instance_name" + table_id = "my_table" + with self._make_client() as client: + with client.get_table( + instance, table_id, app_profile_id=expected_profile + ) as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", mock.Mock() + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream([]) + table.sample_row_keys(attempt_timeout=expected_timeout) + (args, kwargs) = sample_row_keys.call_args + assert len(args) == 0 + assert len(kwargs) == 5 + assert kwargs["timeout"] == expected_timeout + assert kwargs["app_profile_id"] == expected_profile + assert kwargs["table_name"] == table.table_name + assert kwargs["metadata"] is not None + assert kwargs["retry"] is None - def test_table_ctor_invalid_timeout_values(self): - """bad timeout values should raise ValueError""" - client = self._make_client() - 
timeout_pairs = [ - ("default_operation_timeout", "default_attempt_timeout"), - ( - "default_read_rows_operation_timeout", - "default_read_rows_attempt_timeout", - ), - ( - "default_mutate_rows_operation_timeout", - "default_mutate_rows_attempt_timeout", - ), - ] - for operation_timeout, attempt_timeout in timeout_pairs: - with pytest.raises(ValueError) as e: - Table(client, "", "", **{attempt_timeout: -1}) - assert "attempt_timeout must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - Table(client, "", "", **{operation_timeout: -1}) - assert "operation_timeout must be greater than 0" in str(e.value) - client.close() + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_sample_row_keys_retryable_errors(self, retryable_exception): + """retryable errors should be retried until timeout""" + from google.api_core.exceptions import DeadlineExceeded + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", mock.Mock() + ) as sample_row_keys: + sample_row_keys.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + table.sample_row_keys(operation_timeout=0.05) + cause = e.value.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert len(cause.exceptions) > 0 + assert isinstance(cause.exceptions[0], retryable_exception) @pytest.mark.parametrize( - "fn_name,fn_args,is_stream,extra_retryables", - [ - ("read_rows_stream", (ReadRowsQuery(),), True, ()), - ("read_rows", (ReadRowsQuery(),), True, ()), - ("read_row", (b"row_key",), True, ()), - ("read_rows_sharded", ([ReadRowsQuery()],), True, ()), - ("row_exists", (b"row_key",), True, ()), - ("sample_row_keys", (), False, ()), - ("mutate_row", (b"row_key", [mock.Mock()]), False, ()), - ( - "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), - False, - (_MutateRowsIncomplete,), - ), + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + core_exceptions.Aborted, ], ) + def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): + """non-retryable errors should cause a raise""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", mock.Mock() + ) as sample_row_keys: + sample_row_keys.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + table.sample_row_keys() + + +class TestMutateRow(ABC): + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + @pytest.mark.parametrize( - "input_retryables,expected_retryables", + "mutation_arg", [ - ( - TABLE_DEFAULT.READ_ROWS, - [ - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - core_exceptions.Aborted, - ], - ), - ( - TABLE_DEFAULT.DEFAULT, - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ), - ( - TABLE_DEFAULT.MUTATE_ROWS, - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + mutations.SetCell("family", b"qualifier", b"value"), + mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=1234567890 ), - ([], []), - ([4], 
[core_exceptions.DeadlineExceeded]), + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromFamily("family"), + mutations.DeleteAllFromRow(), + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromRow(), + ], ], ) - def test_customizable_retryable_errors( - self, - input_retryables, - expected_retryables, - fn_name, - fn_args, - is_stream, - extra_retryables, - ): - """ - Test that retryable functions support user-configurable arguments, and that the configured retryables are passed - down to the gapic layer. - """ - retry_fn = "retry_target" - if is_stream: - retry_fn += "_stream" - if self.is_async: - retry_fn += "_async" - with mock.patch(f"google.api_core.retry.{retry_fn}") as retry_fn_mock: - with self._make_client() as client: - table = client.get_table("instance-id", "table-id") - expected_predicate = lambda a: a in expected_retryables - retry_fn_mock.side_effect = RuntimeError("stop early") - with mock.patch( - "google.api_core.retry.if_exception_type" - ) as predicate_builder_mock: - predicate_builder_mock.return_value = expected_predicate - with pytest.raises(Exception): - test_fn = table.__getattribute__(fn_name) - test_fn(*fn_args, retryable_errors=input_retryables) - predicate_builder_mock.assert_called_once_with( - *expected_retryables, *extra_retryables + def test_mutate_row(self, mutation_arg): + """Test mutations with no errors""" + expected_attempt_timeout = 19 + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.return_value = None + table.mutate_row( + "row_key", + mutation_arg, + attempt_timeout=expected_attempt_timeout, ) - retry_call_args = retry_fn_mock.call_args_list[0].args - assert retry_call_args[1] is expected_predicate + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0].kwargs + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["row_key"] == b"row_key" + formatted_mutations = ( + [mutation._to_pb() for mutation in mutation_arg] + if isinstance(mutation_arg, list) + else [mutation_arg._to_pb()] + ) + assert kwargs["mutations"] == formatted_mutations + assert kwargs["timeout"] == expected_attempt_timeout + assert kwargs["retry"] is None @pytest.mark.parametrize( - "fn_name,fn_args,gapic_fn", + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_mutate_row_retryable_errors(self, retryable_exception): + from google.api_core.exceptions import DeadlineExceeded + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + mutation = mutations.DeleteAllFromRow() + assert mutation.is_idempotent() is True + table.mutate_row("row_key", mutation, operation_timeout=0.01) + cause = e.value.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def 
test_mutate_row_non_idempotent_retryable_errors(self, retryable_exception): + """Non-idempotent mutations should not be retried""" + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(retryable_exception): + mutation = mutations.SetCell( + "family", b"qualifier", b"value", -1 + ) + assert mutation.is_idempotent() is False + table.mutate_row("row_key", mutation, operation_timeout=0.2) + + @pytest.mark.parametrize( + "non_retryable_exception", [ - ("read_rows_stream", (ReadRowsQuery(),), "read_rows"), - ("read_rows", (ReadRowsQuery(),), "read_rows"), - ("read_row", (b"row_key",), "read_rows"), - ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), - ("row_exists", (b"row_key",), "read_rows"), - ("sample_row_keys", (), "sample_row_keys"), - ("mutate_row", (b"row_key", [mock.Mock()]), "mutate_row"), - ( - "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), - "mutate_rows", - ), - ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), - ( - "read_modify_write_row", - (b"row_key", mock.Mock()), - "read_modify_write_row", - ), + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + core_exceptions.Aborted, ], ) + def test_mutate_row_non_retryable_errors(self, non_retryable_exception): + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + mutation = mutations.SetCell( + "family", + b"qualifier", + b"value", + timestamp_micros=1234567890, + ) + assert mutation.is_idempotent() is True + table.mutate_row("row_key", mutation, operation_timeout=0.2) + @pytest.mark.parametrize("include_app_profile", [True, False]) - def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): - """check that all requests attach proper metadata headers""" + def test_mutate_row_metadata(self, include_app_profile): + """request should attach metadata headers""" profile = "profile" if include_app_profile else None - with mock.patch.object( - BigtableClient, gapic_fn, mock.mock.Mock() - ) as gapic_mock: - gapic_mock.side_effect = RuntimeError("stop early") - with self._make_client() as client: - table = Table(client, "instance-id", "table-id", profile) - try: - test_fn = table.__getattribute__(fn_name) - maybe_stream = test_fn(*fn_args) - [i for i in maybe_stream] - except Exception: - pass - kwargs = gapic_mock.call_args_list[0].kwargs + with self._make_client() as client: + with client.get_table("i", "t", app_profile_id=profile) as table: + with mock.patch.object( + client._gapic_client, "mutate_row", mock.Mock() + ) as read_rows: + table.mutate_row("rk", mock.Mock()) + kwargs = read_rows.call_args_list[0].kwargs metadata = kwargs["metadata"] goog_metadata = None for key, value in metadata: @@ -1164,3 +3772,609 @@ def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): assert "app_profile_id=profile" in goog_metadata else: assert "app_profile_id=" not in goog_metadata + + @pytest.mark.parametrize("mutations", [[], None]) + def test_mutate_row_no_mutations(self, mutations): + with 
self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.mutate_row("key", mutations=mutations) + assert e.value.args[0] == "No mutations provided" + + +class TestBulkMutateRows(ABC): + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + + def _mock_response(self, response_list): + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc import status_pb2 + + statuses = [] + for response in response_list: + if isinstance(response, core_exceptions.GoogleAPICallError): + statuses.append( + status_pb2.Status( + message=str(response), code=response.grpc_status_code.value[0] + ) + ) + else: + statuses.append(status_pb2.Status(code=0)) + entries = [ + MutateRowsResponse.Entry(index=i, status=statuses[i]) + for i in range(len(response_list)) + ] + + def generator(): + yield MutateRowsResponse(entries=entries) + + return generator() + + @pytest.mark.parametrize( + "mutation_arg", + [ + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=1234567890 + ) + ], + [mutations.DeleteRangeFromColumn("family", b"qualifier")], + [mutations.DeleteAllFromFamily("family")], + [mutations.DeleteAllFromRow()], + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromRow(), + ], + ], + ) + def test_bulk_mutate_rows(self, mutation_arg): + """Test mutations with no errors""" + expected_attempt_timeout = 19 + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.return_value = self._mock_response([None]) + bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg) + table.bulk_mutate_rows( + [bulk_mutation], attempt_timeout=expected_attempt_timeout + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"] == [bulk_mutation._to_pb()] + assert kwargs["timeout"] == expected_attempt_timeout + assert kwargs["retry"] is None + + def test_bulk_mutate_rows_multiple_entries(self): + """Test mutations with no errors""" + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.return_value = self._mock_response([None, None]) + mutation_list = [mutations.DeleteAllFromRow()] + entry_1 = mutations.RowMutationEntry(b"row_key_1", mutation_list) + entry_2 = mutations.RowMutationEntry(b"row_key_2", mutation_list) + table.bulk_mutate_rows([entry_1, entry_2]) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"][0] == entry_1._to_pb() + assert kwargs["entries"][1] == entry_2._to_pb() + + @pytest.mark.parametrize( + "exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_rows_idempotent_mutation_error_retryable(self, exception): + """Individual idempotent mutations should be retried if they fail with a retryable error""" + from google.cloud.bigtable.data.exceptions import ( + 
RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], exception) + assert isinstance( + cause.exceptions[-1], core_exceptions.DeadlineExceeded + ) + + @pytest.mark.parametrize( + "exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + core_exceptions.Aborted, + ], + ) + def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable(self, exception): + """Individual idempotent mutations should not be retried if they fail with a non-retryable error""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, exception) + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_idempotent_retryable_request_errors(self, retryable_exception): + """Individual idempotent mutations should be retried if the request fails with a retryable error""" + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + 
cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_rows_non_idempotent_retryable_errors( + self, retryable_exception + ): + """Non-Idempotent mutations should never be retried""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [retryable_exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", -1 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is False + table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, retryable_exception) + + @pytest.mark.parametrize( + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + ], + ) + def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception): + """If the request fails with a non-retryable error, mutations should not be retried""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, non_retryable_exception) + + def test_bulk_mutate_error_index(self): + """Test partial failure, partial success. 
Errors should be associated with the correct index""" + from google.api_core.exceptions import ( + DeadlineExceeded, + ServiceUnavailable, + FailedPrecondition, + ) + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = [ + self._mock_response([None, ServiceUnavailable("mock"), None]), + self._mock_response([DeadlineExceeded("mock")]), + self._mock_response([FailedPrecondition("final")]), + ] + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entries = [ + mutations.RowMutationEntry( + f"row_key_{i}".encode(), [mutation] + ) + for i in range(3) + ] + assert mutation.is_idempotent() is True + table.bulk_mutate_rows(entries, operation_timeout=1000) + assert len(e.value.exceptions) == 1 + failed = e.value.exceptions[0] + assert isinstance(failed, FailedMutationEntryError) + assert failed.index == 1 + assert failed.entry == entries[1] + cause = failed.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert len(cause.exceptions) == 3 + assert isinstance(cause.exceptions[0], ServiceUnavailable) + assert isinstance(cause.exceptions[1], DeadlineExceeded) + assert isinstance(cause.exceptions[2], FailedPrecondition) + + def test_bulk_mutate_error_recovery(self): + """If an error occurs, then resolves, no exception should be raised""" + from google.api_core.exceptions import DeadlineExceeded + + with self._make_client(project="project") as client: + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.side_effect = [ + self._mock_response([DeadlineExceeded("mock")]), + self._mock_response([None]), + ] + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entries = [ + mutations.RowMutationEntry(f"row_key_{i}".encode(), [mutation]) + for i in range(3) + ] + table.bulk_mutate_rows(entries, operation_timeout=1000) + + +class TestCheckAndMutateRow(ABC): + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + + @pytest.mark.parametrize("gapic_result", [True, False]) + def test_check_and_mutate(self, gapic_result): + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + app_profile = "app_profile_id" + with self._make_client() as client: + with client.get_table( + "instance", "table", app_profile_id=app_profile + ) as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=gapic_result + ) + row_key = b"row_key" + predicate = None + true_mutations = [mock.Mock()] + false_mutations = [mock.Mock(), mock.Mock()] + operation_timeout = 0.2 + found = table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutations, + false_case_mutations=false_mutations, + operation_timeout=operation_timeout, + ) + assert found == gapic_result + kwargs = mock_gapic.call_args[1] + assert kwargs["table_name"] == table.table_name + assert kwargs["row_key"] == row_key + assert kwargs["predicate_filter"] == predicate + assert kwargs["true_mutations"] == [ + m._to_pb() for m in true_mutations + ] + 
assert kwargs["false_mutations"] == [ + m._to_pb() for m in false_mutations + ] + assert kwargs["app_profile_id"] == app_profile + assert kwargs["timeout"] == operation_timeout + assert kwargs["retry"] is None + + def test_check_and_mutate_bad_timeout(self): + """Should raise error if operation_timeout < 0""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=[mock.Mock()], + false_case_mutations=[], + operation_timeout=-1, + ) + assert str(e.value) == "operation_timeout must be greater than 0" + + def test_check_and_mutate_single_mutations(self): + """if single mutations are passed, they should be internally wrapped in a list""" + from google.cloud.bigtable.data.mutations import SetCell + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + true_mutation = SetCell("family", b"qualifier", b"value") + false_mutation = SetCell("family", b"qualifier", b"value") + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["true_mutations"] == [true_mutation._to_pb()] + assert kwargs["false_mutations"] == [false_mutation._to_pb()] + + def test_check_and_mutate_predicate_object(self): + """predicate filter should be passed to gapic request""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + mock_predicate = mock.Mock() + predicate_pb = {"predicate": "dict"} + mock_predicate._to_pb.return_value = predicate_pb + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + table.check_and_mutate_row( + b"row_key", mock_predicate, false_case_mutations=[mock.Mock()] + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["predicate_filter"] == predicate_pb + assert mock_predicate._to_pb.call_count == 1 + assert kwargs["retry"] is None + + def test_check_and_mutate_mutations_parsing(self): + """mutations objects should be converted to protos""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + from google.cloud.bigtable.data.mutations import DeleteAllFromRow + + mutations = [mock.Mock() for _ in range(5)] + for idx, mutation in enumerate(mutations): + mutation._to_pb.return_value = f"fake {idx}" + mutations.append(DeleteAllFromRow()) + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=mutations[0:2], + false_case_mutations=mutations[2:], + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["true_mutations"] == ["fake 0", "fake 1"] + assert kwargs["false_mutations"] == [ + "fake 2", + "fake 3", + "fake 4", + DeleteAllFromRow()._to_pb(), + ] + assert all( + (mutation._to_pb.call_count == 1 for mutation in 
mutations[:5]) + ) + + +class TestReadModifyWriteRow(ABC): + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + + @pytest.mark.parametrize( + "call_rules,expected_rules", + [ + ( + AppendValueRule("f", "c", b"1"), + [AppendValueRule("f", "c", b"1")._to_pb()], + ), + ( + [AppendValueRule("f", "c", b"1")], + [AppendValueRule("f", "c", b"1")._to_pb()], + ), + (IncrementRule("f", "c", 1), [IncrementRule("f", "c", 1)._to_pb()]), + ( + [AppendValueRule("f", "c", b"1"), IncrementRule("f", "c", 1)], + [ + AppendValueRule("f", "c", b"1")._to_pb(), + IncrementRule("f", "c", 1)._to_pb(), + ], + ), + ], + ) + def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): + """Test that the gapic call is called with given rules""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row("key", call_rules) + assert mock_gapic.call_count == 1 + found_kwargs = mock_gapic.call_args_list[0][1] + assert found_kwargs["rules"] == expected_rules + assert found_kwargs["retry"] is None + + @pytest.mark.parametrize("rules", [[], None]) + def test_read_modify_write_no_rules(self, rules): + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.read_modify_write_row("key", rules=rules) + assert e.value.args[0] == "rules must contain at least one item" + + def test_read_modify_write_call_defaults(self): + instance = "instance1" + table_id = "table1" + project = "project1" + row_key = "row_key1" + with self._make_client(project=project) as client: + with client.get_table(instance, table_id) as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert ( + kwargs["table_name"] + == f"projects/{project}/instances/{instance}/tables/{table_id}" + ) + assert kwargs["app_profile_id"] is None + assert kwargs["row_key"] == row_key.encode() + assert kwargs["timeout"] > 1 + + def test_read_modify_write_call_overrides(self): + row_key = b"row_key1" + expected_timeout = 12345 + profile_id = "profile1" + with self._make_client() as client: + with client.get_table( + "instance", "table_id", app_profile_id=profile_id + ) as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row( + row_key, mock.Mock(), operation_timeout=expected_timeout + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["app_profile_id"] is profile_id + assert kwargs["row_key"] == row_key + assert kwargs["timeout"] == expected_timeout + + def test_read_modify_write_string_key(self): + row_key = "string_row_key1" + with self._make_client() as client: + with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["row_key"] == row_key.encode() + + def test_read_modify_write_row_building(self): + """results from gapic call should be used to construct row""" + from google.cloud.bigtable.data.row import Row + from 
google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + from google.cloud.bigtable_v2.types import Row as RowPB + + mock_response = ReadModifyWriteRowResponse(row=RowPB()) + with self._make_client() as client: + with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + with mock.patch.object(Row, "_from_pb") as constructor_mock: + mock_gapic.return_value = mock_response + table.read_modify_write_row("key", mock.Mock()) + assert constructor_mock.call_count == 1 + constructor_mock.assert_called_once_with(mock_response.row) From 751c67646c8f2232aa0a44222dc9298760e3189f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Apr 2024 17:08:20 -0700 Subject: [PATCH 043/360] fixed slow test --- tests/unit/data/_async/test__mutate_rows.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 81151f9b6..75c5bef0a 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -136,17 +136,14 @@ def test_ctor_too_many_entries(self): client = mock.Mock() table = mock.Mock() - entries = [self._make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT + entries = [self._make_mutation()] * (_MUTATE_ROWS_REQUEST_MUTATION_LIMIT + 1) operation_timeout = 0.05 attempt_timeout = 0.01 - # no errors if at limit - self._make_one(client, table, entries, operation_timeout, attempt_timeout) - # raise error after crossing with pytest.raises(ValueError) as e: self._make_one( client, table, - entries + [self._make_mutation()], + entries, operation_timeout, attempt_timeout, ) From 97f479749772653ff3334071a03045406e7923af Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Apr 2024 17:23:12 -0700 Subject: [PATCH 044/360] got async read_rows acceptance working again --- tests/unit/data/_async/__init__.py | 0 .../data/_async/test_read_rows_acceptance.py | 519 +++++++++--------- tests/unit/data/_sync/__init__.py | 0 3 files changed, 261 insertions(+), 258 deletions(-) create mode 100644 tests/unit/data/_async/__init__.py create mode 100644 tests/unit/data/_sync/__init__.py diff --git a/tests/unit/data/_async/__init__.py b/tests/unit/data/_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 7cb3c08dc..b45ab75fc 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -14,11 +14,12 @@ from __future__ import annotations import os -from itertools import zip_longest - +import warnings import pytest import mock +from itertools import zip_longest + from google.cloud.bigtable_v2 import ReadRowsResponse from google.cloud.bigtable.data._async.client import BigtableDataClientAsync @@ -26,108 +27,75 @@ from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data.row import Row -from ..v2_client.test_row_merger import ReadRowsTest, TestFile +from ...v2_client.test_row_merger import ReadRowsTest, TestFile -def parse_readrows_acceptance_tests(): - dirname = os.path.dirname(__file__) - filename = os.path.join(dirname, "./read-rows-acceptance-test.json") +class TestReadRowsAcceptanceAsync: - with open(filename) as json_file: - test_json = TestFile.from_json(json_file.read()) - return test_json.read_rows_tests + def 
parse_readrows_acceptance_tests(): + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "../read-rows-acceptance-test.json") + with open(filename) as json_file: + test_json = TestFile.from_json(json_file.read()) + return test_json.read_rows_tests -def extract_results_from_row(row: Row): - results = [] - for family, col, cells in row.items(): - for cell in cells: - results.append( - ReadRowsTest.Result( - row_key=row.row_key, - family_name=family, - qualifier=col, - timestamp_micros=cell.timestamp_ns // 1000, - value=cell.value, - label=(cell.labels[0] if cell.labels else ""), + @staticmethod + def extract_results_from_row(row: Row): + results = [] + for family, col, cells in row.items(): + for cell in cells: + results.append( + ReadRowsTest.Result( + row_key=row.row_key, + family_name=family, + qualifier=col, + timestamp_micros=cell.timestamp_ns // 1000, + value=cell.value, + label=(cell.labels[0] if cell.labels else ""), + ) ) - ) - return results + return results + @staticmethod + async def _coro_wrapper(stream): + return stream -@pytest.mark.parametrize( - "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description -) -@pytest.mark.asyncio -async def test_row_merger_scenario(test_case: ReadRowsTest): - async def _scenerio_stream(): - for chunk in test_case.chunks: - yield ReadRowsResponse(chunks=[chunk]) + async def _process_chunks(self, *chunks): + async def _row_stream(): + yield ReadRowsResponse(chunks=chunks) - try: - results = [] instance = mock.Mock() - instance._last_yielded_row_key = None instance._remaining_count = None + instance._last_yielded_row_key = None chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_scenerio_stream()) + instance, self._coro_wrapper(_row_stream()) ) merger = _ReadRowsOperationAsync.merge_rows(chunker) + results = [] async for row in merger: - for cell in row: - cell_result = ReadRowsTest.Result( - row_key=cell.row_key, - family_name=cell.family, - qualifier=cell.qualifier, - timestamp_micros=cell.timestamp_micros, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - -@pytest.mark.parametrize( - "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description -) -@pytest.mark.asyncio -async def test_read_rows_scenario(test_case: ReadRowsTest): - async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): - from google.cloud.bigtable_v2 import ReadRowsResponse - - class mock_stream: - def __init__(self, chunk_list): - self.chunk_list = chunk_list - self.idx = -1 - - def __aiter__(self): - return self - - async def __anext__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - chunk = self.chunk_list[self.idx] - return ReadRowsResponse(chunks=[chunk]) - raise StopAsyncIteration - - def cancel(self): - pass - - return mock_stream(chunk_list) + results.append(row) + return results - try: - with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): - # use emulator mode to avoid auth issues in CI - client = BigtableDataClientAsync() - table = client.get_table("instance", "table") - results = [] - with mock.patch.object(table.client._gapic_client, "read_rows") as read_rows: - # run once, then return error on retry - read_rows.return_value = _make_gapic_stream(test_case.chunks) - async for row in await table.read_rows_stream(query={}): + 
@pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + @pytest.mark.asyncio + async def test_row_merger_scenario(self, test_case: ReadRowsTest): + async def _scenerio_stream(): + for chunk in test_case.chunks: + yield ReadRowsResponse(chunks=[chunk]) + + try: + results = [] + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + chunker = _ReadRowsOperationAsync.chunk_stream( + instance, self._coro_wrapper(_scenerio_stream()) + ) + merger = _ReadRowsOperationAsync.merge_rows(chunker) + async for row in merger: for cell in row: cell_result = ReadRowsTest.Result( row_key=cell.row_key, @@ -138,194 +106,229 @@ def cancel(self): label=cell.labels[0] if cell.labels else "", ) results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - finally: - await client.close() - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - -@pytest.mark.asyncio -async def test_out_of_order_rows(): - async def _row_stream(): - yield ReadRowsResponse(last_scanned_row_key=b"a") - - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = b"b" - chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_row_stream()) - ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) - with pytest.raises(InvalidChunk): - async for _ in merger: - pass + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected -@pytest.mark.asyncio -async def test_bare_reset(): - first_chunk = ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk( - row_key=b"a", family_name="f", qualifier=b"q", value=b"v" - ) + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, family_name="f") - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, value=b"v") - ), - ) - + @pytest.mark.asyncio + async def test_read_rows_scenario(self, test_case: ReadRowsTest): + async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list): + self.chunk_list = chunk_list + self.idx = -1 + + def __aiter__(self): + return self + + async def __anext__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + chunk = self.chunk_list[self.idx] + return ReadRowsResponse(chunks=[chunk]) + raise StopAsyncIteration + + def 
cancel(self): + pass + + return mock_stream(chunk_list) + + try: + with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # use emulator mode to avoid auth issues in CI + client = BigtableDataClientAsync() + table = client.get_table("instance", "table") + results = [] + with mock.patch.object(table.client._gapic_client, "read_rows") as read_rows: + # run once, then return error on retry + read_rows.return_value = _make_gapic_stream(test_case.chunks) + async for row in await table.read_rows_stream(query={}): + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + finally: + await client.close() + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + + @pytest.mark.asyncio + async def test_out_of_order_rows(self): + async def _row_stream(): + yield ReadRowsResponse(last_scanned_row_key=b"a") -@pytest.mark.asyncio -async def test_missing_family(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - qualifier=b"q", - timestamp_micros=1000, - value=b"v", - commit_row=True, - ) + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = b"b" + chunker = _ReadRowsOperationAsync.chunk_stream( + instance, self._coro_wrapper(_row_stream()) ) + merger = _ReadRowsOperationAsync.merge_rows(chunker) + with pytest.raises(InvalidChunk): + async for _ in merger: + pass -@pytest.mark.asyncio -async def test_mid_cell_row_key_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( + @pytest.mark.asyncio + async def test_bare_reset(self): + first_chunk = ReadRowsResponse.CellChunk( ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), + row_key=b"a", family_name="f", qualifier=b"q", value=b"v" + ) ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, family_name="f") + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, value=b"v") + ), + ) -@pytest.mark.asyncio -async def test_mid_cell_family_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - 
row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(family_name="f2", value=b"v", commit_row=True), - ) - + @pytest.mark.asyncio + async def test_missing_family(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + qualifier=b"q", + timestamp_micros=1000, + value=b"v", + commit_row=True, + ) + ) -@pytest.mark.asyncio -async def test_mid_cell_qualifier_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(qualifier=b"q2", value=b"v", commit_row=True), - ) + @pytest.mark.asyncio + async def test_mid_cell_row_key_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), + ) -@pytest.mark.asyncio -async def test_mid_cell_timestamp_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - timestamp_micros=2000, value=b"v", commit_row=True - ), - ) + @pytest.mark.asyncio + async def test_mid_cell_family_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(family_name="f2", value=b"v", commit_row=True), + ) -@pytest.mark.asyncio -async def test_mid_cell_labels_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), - ) + @pytest.mark.asyncio + async def test_mid_cell_qualifier_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(qualifier=b"q2", value=b"v", commit_row=True), + ) -async def _coro_wrapper(stream): - return stream + @pytest.mark.asyncio + async def test_mid_cell_timestamp_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + timestamp_micros=2000, value=b"v", commit_row=True + ), + ) -async def _process_chunks(*chunks): - async def _row_stream(): - yield ReadRowsResponse(chunks=chunks) - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = None - chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_row_stream()) - ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) - results = [] - async for row in merger: - results.append(row) - return results + @pytest.mark.asyncio + async def test_mid_cell_labels_change(self): + with pytest.raises(InvalidChunk): 
+ await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), + ) diff --git a/tests/unit/data/_sync/__init__.py b/tests/unit/data/_sync/__init__.py new file mode 100644 index 000000000..e69de29bb From 92f21327f041796feb43adad5eaa118977912797 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Apr 2024 19:37:21 -0700 Subject: [PATCH 045/360] got acceptance tests working --- .../cloud/bigtable/data/_sync/unit_tests.yaml | 9 + .../data/_async/test_read_rows_acceptance.py | 37 +- tests/unit/data/_sync/test_autogen.py | 316 +++++++++++++++++- 3 files changed, 338 insertions(+), 24 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/unit_tests.yaml b/google/cloud/bigtable/data/_sync/unit_tests.yaml index 1475f94b1..6a4cb2159 100644 --- a/google/cloud/bigtable/data/_sync/unit_tests.yaml +++ b/google/cloud/bigtable/data/_sync/unit_tests.yaml @@ -103,5 +103,14 @@ classes: autogen_sync_name: TestCheckAndMutateRow - path: tests.unit.data._async.test_client.TestReadModifyWriteRowAsync autogen_sync_name: TestReadModifyWriteRow + - path: tests.unit.data._async.test_read_rows_acceptance.TestReadRowsAcceptanceAsync + autogen_sync_name: TestReadRowsAcceptance + replace_methods: + _get_operation_class: | + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + return _ReadRowsOperation + _get_client_class: | + from google.cloud.bigtable.data._sync.client import BigtableDataClient + return BigtableDataClient save_path: "tests/unit/data/_sync/test_autogen.py" diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index b45ab75fc..38c383204 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -22,16 +22,24 @@ from google.cloud.bigtable_v2 import ReadRowsResponse -from google.cloud.bigtable.data._async.client import BigtableDataClientAsync from google.cloud.bigtable.data.exceptions import InvalidChunk -from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data.row import Row -from ...v2_client.test_row_merger import ReadRowsTest, TestFile +from tests.unit.v2_client.test_row_merger import ReadRowsTest, TestFile class TestReadRowsAcceptanceAsync: + @staticmethod + def _get_operation_class(): + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + return _ReadRowsOperationAsync + + @staticmethod + def _get_client_class(): + from google.cloud.bigtable.data._async.client import BigtableDataClientAsync + return BigtableDataClientAsync + def parse_readrows_acceptance_tests(): dirname = os.path.dirname(__file__) filename = os.path.join(dirname, "../read-rows-acceptance-test.json") @@ -68,10 +76,10 @@ async def _row_stream(): instance = mock.Mock() instance._remaining_count = None instance._last_yielded_row_key = None - chunker = _ReadRowsOperationAsync.chunk_stream( + chunker = self._get_operation_class().chunk_stream( instance, self._coro_wrapper(_row_stream()) ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) + merger = self._get_operation_class().merge_rows(chunker) results = [] async for row in merger: results.append(row) @@ -91,10 +99,10 @@ async def _scenerio_stream(): instance = mock.Mock() instance._last_yielded_row_key = None instance._remaining_count = None - chunker = 
_ReadRowsOperationAsync.chunk_stream( + chunker = self._get_operation_class().chunk_stream( instance, self._coro_wrapper(_scenerio_stream()) ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) + merger = self._get_operation_class().merge_rows(chunker) async for row in merger: for cell in row: cell_result = ReadRowsTest.Result( @@ -139,13 +147,12 @@ def cancel(self): pass return mock_stream(chunk_list) - + with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # use emulator mode to avoid auth issues in CI + client = self._get_client_class()() try: - with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # use emulator mode to avoid auth issues in CI - client = BigtableDataClientAsync() table = client.get_table("instance", "table") results = [] with mock.patch.object(table.client._gapic_client, "read_rows") as read_rows: @@ -178,10 +185,10 @@ async def _row_stream(): instance = mock.Mock() instance._remaining_count = None instance._last_yielded_row_key = b"b" - chunker = _ReadRowsOperationAsync.chunk_stream( + chunker = self._get_operation_class().chunk_stream( instance, self._coro_wrapper(_row_stream()) ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) + merger = self._get_operation_class().merge_rows(chunker) with pytest.raises(InvalidChunk): async for _ in merger: pass diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py index 4aae958ed..49cc2dbbe 100644 --- a/tests/unit/data/_sync/test_autogen.py +++ b/tests/unit/data/_sync/test_autogen.py @@ -17,18 +17,23 @@ from __future__ import annotations from abc import ABC +from itertools import zip_longest from tests.unit.data._async.test__mutate_rows import TestMutateRowsOperation from tests.unit.data._async.test__read_rows import TestReadRowsOperation from tests.unit.data._async.test_mutations_batcher import Test_FlowControl +from tests.unit.v2_client.test_row_merger import ReadRowsTest +from tests.unit.v2_client.test_row_merger import TestFile from unittest import mock import asyncio import concurrent.futures import grpc import mock +import os import pytest import re import threading import time +import warnings from google.api_core import exceptions as core_exceptions from google.api_core import grpc_helpers @@ -42,6 +47,7 @@ from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.row import Row from google.cloud.bigtable_v2 import ReadRowsResponse from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( @@ -154,18 +160,11 @@ def test_ctor_too_many_entries(self): assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100000 client = mock.Mock() table = mock.Mock() - entries = [self._make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT + entries = [self._make_mutation()] * (_MUTATE_ROWS_REQUEST_MUTATION_LIMIT + 1) operation_timeout = 0.05 attempt_timeout = 0.01 - self._make_one(client, table, entries, operation_timeout, attempt_timeout) with pytest.raises(ValueError) as e: - self._make_one( - client, - table, - entries + [self._make_mutation()], - operation_timeout, - attempt_timeout, - ) + self._make_one(client, table, 
entries, operation_timeout, attempt_timeout) assert "mutate_rows requests can contain at most 100000 mutations" in str( e.value ) @@ -4378,3 +4377,302 @@ def test_read_modify_write_row_building(self): table.read_modify_write_row("key", mock.Mock()) assert constructor_mock.call_count == 1 constructor_mock.assert_called_once_with(mock_response.row) + + +class TestReadRowsAcceptance(ABC): + @staticmethod + def _get_operation_class(): + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + + return _ReadRowsOperation + + @staticmethod + def _get_client_class(): + from google.cloud.bigtable.data._sync.client import BigtableDataClient + + return BigtableDataClient + + def parse_readrows_acceptance_tests(): + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "../read-rows-acceptance-test.json") + with open(filename) as json_file: + test_json = TestFile.from_json(json_file.read()) + return test_json.read_rows_tests + + @staticmethod + def extract_results_from_row(row: Row): + results = [] + for family, col, cells in row.items(): + for cell in cells: + results.append( + ReadRowsTest.Result( + row_key=row.row_key, + family_name=family, + qualifier=col, + timestamp_micros=cell.timestamp_ns // 1000, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + ) + return results + + @staticmethod + def _coro_wrapper(stream): + return stream + + def _process_chunks(self, *chunks): + def _row_stream(): + yield ReadRowsResponse(chunks=chunks) + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + results = [] + for row in merger: + results.append(row) + return results + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + def test_row_merger_scenario(self, test_case: ReadRowsTest): + def _scenerio_stream(): + for chunk in test_case.chunks: + yield ReadRowsResponse(chunks=[chunk]) + + try: + results = [] + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_scenerio_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + for row in merger: + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + def test_read_rows_scenario(self, test_case: ReadRowsTest): + def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list): + self.chunk_list = chunk_list + self.idx = -1 + + def __iter__(self): + return self + + def __next__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + chunk = self.chunk_list[self.idx] + return ReadRowsResponse(chunks=[chunk]) + raise StopIteration + + def cancel(self): + pass + + return 
mock_stream(chunk_list) + + with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + client = self._get_client_class()() + try: + table = client.get_table("instance", "table") + results = [] + with mock.patch.object( + table.client._gapic_client, "read_rows" + ) as read_rows: + read_rows.return_value = _make_gapic_stream(test_case.chunks) + for row in table.read_rows_stream(query={}): + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + finally: + client.close() + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + def test_out_of_order_rows(self): + def _row_stream(): + yield ReadRowsResponse(last_scanned_row_key=b"a") + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = b"b" + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + with pytest.raises(InvalidChunk): + for _ in merger: + pass + + def test_bare_reset(self): + first_chunk = ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk( + row_key=b"a", family_name="f", qualifier=b"q", value=b"v" + ) + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, family_name="f") + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, value=b"v") + ), + ) + + def test_missing_family(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + qualifier=b"q", + timestamp_micros=1000, + value=b"v", + commit_row=True, + ) + ) + + def test_mid_cell_row_key_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), + ) + + def test_mid_cell_family_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + family_name="f2", value=b"v", commit_row=True + ), + ) + + def test_mid_cell_qualifier_change(self): + with 
pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + qualifier=b"q2", value=b"v", commit_row=True + ), + ) + + def test_mid_cell_timestamp_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + timestamp_micros=2000, value=b"v", commit_row=True + ), + ) + + def test_mid_cell_labels_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), + ) From c0064828297b601fb6dc165cf871a474fe40da64 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Apr 2024 19:46:01 -0700 Subject: [PATCH 046/360] ran blacken --- .../cloud/bigtable/data/_async/_read_rows.py | 4 +- google/cloud/bigtable/data/_async/client.py | 30 ++++-- google/cloud/bigtable/data/_sync/client.py | 19 ++-- .../bigtable/data/_sync/mutations_batcher.py | 2 - .../bigtable/transports/pooled_grpc.py | 15 ++- tests/unit/data/_async/test__mutate_rows.py | 1 - tests/unit/data/_async/test__read_rows.py | 4 +- tests/unit/data/_async/test_client.py | 102 ++++++++++++------ .../data/_async/test_mutations_batcher.py | 74 +++++++++---- .../data/_async/test_read_rows_acceptance.py | 25 +++-- 10 files changed, 185 insertions(+), 91 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index a6fe67847..b50cc9adc 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -190,9 +190,7 @@ async def chunk_stream( current_key = None @staticmethod - async def merge_rows( - chunks: AsyncIterable[ReadRowsResponsePB.CellChunk] | None - ): + async def merge_rows(chunks: AsyncIterable[ReadRowsResponsePB.CellChunk] | None): """ Merge chunks into rows """ diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index ac42b580a..aaed11cca 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -41,7 +41,9 @@ from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( PooledBigtableGrpcAsyncIOTransport, ) -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import PooledChannel as AsyncPooledChannel +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel as AsyncPooledChannel, +) from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest from google.cloud.client import ClientWithProject from google.cloud.environment_vars import BIGTABLE_EMULATOR # type: ignore @@ -200,7 +202,11 @@ def _start_background_channel_refresh(self) -> None: Raises: - RuntimeError if not called in an asyncio event loop """ - if not self._channel_refresh_tasks and not self._emulator_host and not self._is_closed.is_set(): + if ( + not self._channel_refresh_tasks + and not self._emulator_host + and not self._is_closed.is_set() + ): # raise RuntimeError if there is no event loop asyncio.get_running_loop() for channel_idx in range(self.transport.pool_size): @@ 
-324,7 +330,10 @@ async def _manage_channel( # cycle channel out of use, with long grace window before closure start_timestamp = time.monotonic() await self.transport.replace_channel( - channel_idx, grace=grace_period, new_channel=new_channel, event=self._is_closed + channel_idx, + grace=grace_period, + new_channel=new_channel, + event=self._is_closed, ) # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) @@ -552,9 +561,9 @@ def __init__( default_mutate_rows_retryable_errors or () ) self.default_retryable_errors = default_retryable_errors or () - self._register_instance_future: asyncio.Future[None] = self._register_with_client() - - + self._register_instance_future: asyncio.Future[ + None + ] = self._register_with_client() def _register_with_client(self) -> asyncio.Future[None]: """ @@ -819,18 +828,17 @@ async def read_rows_sharded( ) return results_list - async def _shard_batch_helper(self, kwargs_list: list[dict]) -> list[list[Row] | BaseException]: + async def _shard_batch_helper( + self, kwargs_list: list[dict] + ) -> list[list[Row] | BaseException]: """ Helper function for executing a batch of read_rows queries in parallel Sync client implementation will override this method """ - routine_list = [ - self.read_rows(**kwargs) for kwargs in kwargs_list - ] + routine_list = [self.read_rows(**kwargs) for kwargs in kwargs_list] return await asyncio.gather(*routine_list, return_exceptions=True) - async def row_exists( self, row_key: str | bytes, diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 2d1d41814..bce9f671f 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -33,7 +33,7 @@ from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( PooledBigtableGrpcTransport, - PooledChannel + PooledChannel, ) from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest @@ -44,7 +44,6 @@ class BigtableDataClient(BigtableDataClient_SyncGen): - @property def _executor(self) -> concurrent.futures.ThreadPoolExecutor: if not hasattr(self, "_executor_instance"): @@ -56,7 +55,11 @@ def _client_version() -> str: return f"{google.cloud.bigtable.__version__}-data" def _start_background_channel_refresh(self) -> None: - if not self._channel_refresh_tasks and not self._emulator_host and not self._is_closed.is_set(): + if ( + not self._channel_refresh_tasks + and not self._emulator_host + and not self._is_closed.is_set() + ): for channel_idx in range(self.transport.pool_size): self._channel_refresh_tasks.append( self._executor.submit(self._manage_channel, channel_idx) @@ -87,14 +90,18 @@ def close(self) -> None: class Table(Table_SyncGen): - def _register_with_client(self) -> concurrent.futures.Future[None]: return self.client._executor.submit( self.client._register_instance, self.instance_id, self ) - def _shard_batch_helper(self, kwargs_list: list[dict]) -> list[list[Row] | BaseException]: - futures_list = [self.client._executor.submit(self.read_rows, **kwargs) for kwargs in kwargs_list] + def _shard_batch_helper( + self, kwargs_list: list[dict] + ) -> list[list[Row] | BaseException]: + futures_list = [ + self.client._executor.submit(self.read_rows, **kwargs) + for kwargs in kwargs_list + ] results_list: list[list[Row] | BaseException] = [] for future in futures_list: if future.exception(): diff --git 
a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 7bd93873a..f1ea43af8 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -34,7 +34,6 @@ class _FlowControl(_FlowControl_SyncGen): class MutationsBatcher(MutationsBatcher_SyncGen): - @property def _executor(self): """ @@ -91,4 +90,3 @@ def _timer_routine(self, interval: float | None) -> None: self._closed.wait(timeout=interval) if not self._closed.is_set() and self._staged_entries: self._schedule_flush() - diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py index f852017c0..2c808a000 100644 --- a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py +++ b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py @@ -53,6 +53,7 @@ def with_call(self, *args, **kwargs): def future(self, *args, **kwargs): raise NotImplementedError() + class PooledUnaryUnaryMultiCallable(PooledMultiCallable, grpc.UnaryUnaryMultiCallable): def __call__(self, *args, **kwargs): return self.next_channel_fn().unary_unary( @@ -60,14 +61,18 @@ def __call__(self, *args, **kwargs): )(*args, **kwargs) -class PooledUnaryStreamMultiCallable(PooledMultiCallable, grpc.UnaryStreamMultiCallable): +class PooledUnaryStreamMultiCallable( + PooledMultiCallable, grpc.UnaryStreamMultiCallable +): def __call__(self, *args, **kwargs): return self.next_channel_fn().unary_stream( *self._init_args, **self._init_kwargs )(*args, **kwargs) -class PooledStreamUnaryMultiCallable(PooledMultiCallable, grpc.StreamUnaryMultiCallable): +class PooledStreamUnaryMultiCallable( + PooledMultiCallable, grpc.StreamUnaryMultiCallable +): def __call__(self, *args, **kwargs): return self.next_channel_fn().stream_unary( *self._init_args, **self._init_kwargs @@ -149,7 +154,9 @@ def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: def wait_for_state_change(self, last_observed_state): raise NotImplementedError() - def subscribe(self, callback, try_to_connect: bool = False) -> grpc.ChannelConnectivity: + def subscribe( + self, callback, try_to_connect: bool = False + ) -> grpc.ChannelConnectivity: raise NotImplementedError() def unsubscribe(self, callback): @@ -162,7 +169,7 @@ def replace_channel( Replaces a channel in the pool with a fresh one. 
The `new_channel` will start processing new requests immidiately, - but the old channel will continue serving existing clients for + but the old channel will continue serving existing clients for `grace` seconds Args: diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 75c5bef0a..26a9325f0 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -27,7 +27,6 @@ from mock import AsyncMock # type: ignore - class TestMutateRowsOperation: def _target_class(self): from google.cloud.bigtable.data._async._mutate_rows import ( diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index fab338fdf..58d7b30f4 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -354,7 +354,9 @@ async def mock_stream(): instance = mock.Mock() instance._last_yielded_row_key = None instance._remaining_count = None - stream = self._get_target_class().chunk_stream(instance, mock_awaitable_stream()) + stream = self._get_target_class().chunk_stream( + instance, mock_awaitable_stream() + ) await stream.__anext__() with pytest.raises(InvalidChunk) as exc: await stream.__anext__() diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index c2bbdcbb4..4506ee603 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -30,7 +30,9 @@ from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( PooledBigtableGrpcAsyncIOTransport, ) -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import PooledChannel as PooledChannelAsync +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel as PooledChannelAsync, +) from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.api_core import exceptions as core_exceptions from google.cloud.bigtable.data.exceptions import InvalidChunk @@ -50,7 +52,6 @@ class TestBigtableDataClientAsync: - @staticmethod def _get_target_class(): from google.cloud.bigtable.data._async.client import BigtableDataClientAsync @@ -67,7 +68,8 @@ def _make_client(cls, *args, use_emulator=True, **kwargs): if use_emulator: env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" import warnings - warnings.filterwarnings('ignore', category=RuntimeWarning) + + warnings.filterwarnings("ignore", category=RuntimeWarning) else: # set some default values kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) @@ -158,7 +160,9 @@ async def test_ctor_dict_options(self): with mock.patch.object( self._get_target_class(), "_start_background_channel_refresh" ) as start_background_refresh: - client = self._make_client(client_options=client_options, use_emulator=False) + client = self._make_client( + client_options=client_options, use_emulator=False + ) start_background_refresh.assert_called_once() await client.close() @@ -166,7 +170,9 @@ async def test_ctor_dict_options(self): async def test_veneer_grpc_headers(self): client_component = "data-async" if self.is_async else "data" VENEER_HEADER_REGEX = re.compile( - r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-" + client_component + r" gl-python\/[0-9]+\.[\w.-]+ grpc\/[0-9]+\.[\w.-]+" + r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-" + + client_component + + r" gl-python\/[0-9]+\.[\w.-]+ grpc\/[0-9]+\.[\w.-]+" ) # client_info should 
be populated with headers to @@ -194,7 +200,9 @@ async def test_veneer_grpc_headers(self): @pytest.mark.asyncio async def test_channel_pool_creation(self): pool_size = 14 - with mock.patch.object(grpc_helpers_async, "create_channel", AsyncMock()) as create_channel: + with mock.patch.object( + grpc_helpers_async, "create_channel", AsyncMock() + ) as create_channel: client = self._make_client(project="project-id", pool_size=pool_size) assert create_channel.call_count == pool_size await client.close() @@ -232,6 +240,7 @@ async def test_channel_pool_rotation(self): @pytest.mark.asyncio async def test_channel_pool_replace(self): import time + sleep_module = asyncio if self.is_async else time with mock.patch.object(sleep_module, "sleep"): pool_size = 7 @@ -281,8 +290,11 @@ async def test__start_background_channel_refresh_tasks_exist(self): @pytest.mark.parametrize("pool_size", [1, 3, 7]) async def test__start_background_channel_refresh(self, pool_size): import concurrent.futures + # should create background tasks for each channel - with mock.patch.object(self._get_target_class(), "_ping_and_warm_instances", AsyncMock()) as ping_and_warm: + with mock.patch.object( + self._get_target_class(), "_ping_and_warm_instances", AsyncMock() + ) as ping_and_warm: client = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False ) @@ -321,8 +333,14 @@ async def test__ping_and_warm_instances(self): test ping and warm with mocked asyncio.gather """ client_mock = mock.Mock() - client_mock._execute_ping_and_warms = lambda *args: self._get_target_class()._execute_ping_and_warms(client_mock, *args) - gather_tuple = (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) + gather_tuple = ( + (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") + ) with mock.patch.object(*gather_tuple, AsyncMock()) as gather: if self.is_async: # simulate gather by returning the same number of items as passed in @@ -383,8 +401,14 @@ async def test__ping_and_warm_single_instance(self): should be able to call ping and warm with single instance """ client_mock = mock.Mock() - client_mock._execute_ping_and_warms = lambda *args: self._get_target_class()._execute_ping_and_warms(client_mock, *args) - gather_tuple = (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) + gather_tuple = ( + (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") + ) with mock.patch.object(*gather_tuple, AsyncMock()) as gather: gather.side_effect = lambda *args, **kwargs: [mock.Mock() for _ in args] if self.is_async: @@ -438,7 +462,9 @@ async def test__manage_channel_first_sleep( with mock.patch.object(time, "monotonic") as monotonic: monotonic.return_value = 0 - sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + sleep_tuple = ( + (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + ) with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = asyncio.CancelledError try: @@ -523,7 +549,9 @@ async def test__manage_channel_sleeps( uniform.side_effect = lambda min_, max_: min_ with mock.patch.object(time, "time") as time_mock: time_mock.return_value = 0 - sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, 
"wait") + sleep_tuple = ( + (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + ) with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = [None for i in range(num_cycles - 1)] + [ asyncio.CancelledError @@ -592,7 +620,9 @@ async def test__manage_channel_refresh(self, num_cycles): with mock.patch.object( PooledBigtableGrpcAsyncIOTransport, "replace_channel" ) as replace_channel: - sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + sleep_tuple = ( + (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + ) with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = [None for i in range(num_cycles)] + [ asyncio.CancelledError @@ -601,8 +631,12 @@ async def test__manage_channel_refresh(self, num_cycles): grpc_helpers_async, "create_channel" ) as create_channel: create_channel.return_value = new_channel - with mock.patch.object(self._get_target_class(), "_start_background_channel_refresh"): - client = self._make_client(project="project-id", use_emulator=False) + with mock.patch.object( + self._get_target_class(), "_start_background_channel_refresh" + ): + client = self._make_client( + project="project-id", use_emulator=False + ) create_channel.reset_mock() try: await client._manage_channel( @@ -913,7 +947,9 @@ async def test_get_table_arg_passthrough(self): All arguments passed in get_table should be sent to constructor """ async with self._make_client(project="project-id") as client: - with mock.patch.object(TestTableAsync._get_target_class(), "__init__") as mock_constructor: + with mock.patch.object( + TestTableAsync._get_target_class(), "__init__" + ) as mock_constructor: mock_constructor.return_value = None assert not client._active_instances expected_table_id = "table-id" @@ -947,7 +983,9 @@ async def test_get_table_context_manager(self): expected_app_profile_id = "app-profile-id" expected_project_id = "project-id" - with mock.patch.object(TestTableAsync._get_target_class(), "close") as close_mock: + with mock.patch.object( + TestTableAsync._get_target_class(), "close" + ) as close_mock: async with self._make_client(project=expected_project_id) as client: async with client.get_table( expected_instance_id, @@ -1060,13 +1098,13 @@ def test_client_ctor_sync(self): class TestTableAsync: - def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @staticmethod def _get_target_class(): from google.cloud.bigtable.data._async.client import TableAsync + return TableAsync @property @@ -1347,7 +1385,9 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ from google.cloud.bigtable.data import TableAsync profile = "profile" if include_app_profile else None - with mock.patch.object(BigtableAsyncClient, gapic_fn, mock.AsyncMock()) as gapic_mock: + with mock.patch.object( + BigtableAsyncClient, gapic_fn, mock.AsyncMock() + ) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") async with self._make_client() as client: table = TableAsync(client, "instance-id", "table-id", profile) @@ -1380,6 +1420,7 @@ class TestReadRowsAsync: @staticmethod def _get_operation_class(): from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + return _ReadRowsOperationAsync def _make_client(self, *args, **kwargs): @@ -1879,7 +1920,6 @@ async def test_row_exists(self, return_value, expected_result): class TestReadRowsShardedAsync: - def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, 
**kwargs) @@ -2000,9 +2040,10 @@ async def test_read_rows_sharded_batching(self): client = self._make_client(use_emulator=True) table = client.get_table( - "instance", "table", + "instance", + "table", default_read_rows_operation_timeout=start_operation_timeout, - default_read_rows_attempt_timeout=start_attempt_timeout + default_read_rows_attempt_timeout=start_attempt_timeout, ) # make timeout generator that reduces timeout by one each call @@ -2010,7 +2051,9 @@ def mock_time_generator(start_op, _): for i in range(0, 100000): yield start_op - i - with mock.patch(f"google.cloud.bigtable.data._helpers._attempt_timeout_generator") as time_gen_mock: + with mock.patch( + f"google.cloud.bigtable.data._helpers._attempt_timeout_generator" + ) as time_gen_mock: time_gen_mock.side_effect = mock_time_generator with mock.patch.object(table, "read_rows", AsyncMock()) as read_rows_mock: @@ -2020,8 +2063,7 @@ def mock_time_generator(start_op, _): assert read_rows_mock.call_count == n_queries # ensure that timeouts decrease over time kwargs = [ - read_rows_mock.call_args_list[idx][1] - for idx in range(n_queries) + read_rows_mock.call_args_list[idx][1] for idx in range(n_queries) ] for batch_idx in range(expected_num_batches): batch_kwargs = kwargs[ @@ -2035,7 +2077,8 @@ def mock_time_generator(start_op, _): batch_idx ) assert ( - req_kwargs["operation_timeout"] == expected_operation_timeout + req_kwargs["operation_timeout"] + == expected_operation_timeout ) # each attempt_timeout should start with default value, but decrease when operation_timeout reaches it expected_attempt_timeout = min( @@ -2045,7 +2088,6 @@ def mock_time_generator(start_op, _): class TestSampleRowKeysAsync: - def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2197,7 +2239,6 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio class TestMutateRowAsync: - def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2373,7 +2414,6 @@ async def test_mutate_row_no_mutations(self, mutations): class TestBulkMutateRowsAsync: - def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2753,7 +2793,6 @@ async def test_bulk_mutate_error_recovery(self): class TestCheckAndMutateRowAsync: - def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2905,7 +2944,6 @@ async def test_check_and_mutate_mutations_parsing(self): class TestReadModifyWriteRowAsync: - def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index adb92b63f..65fc4766b 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -31,12 +31,12 @@ class Test_FlowControl: - @staticmethod def _target_class(): from google.cloud.bigtable.data._async.mutations_batcher import ( _FlowControlAsync, ) + return _FlowControlAsync def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): @@ -154,6 +154,7 @@ async def test_remove_from_flow_value_update( async def test__remove_from_flow_unlock(self): """capacity condition should notify after mutation is complete""" import inspect + instance = self._make_one(10, 10) instance._in_flight_mutation_count = 10 instance._in_flight_mutation_bytes = 10 @@ -163,6 +164,7 @@ async def 
task_routine(): await instance._capacity_condition.wait_for( lambda: instance._has_capacity(1, 1) ) + if inspect.iscoroutinefunction(task_routine): # for async class, build task to test flow unlock task = asyncio.create_task(task_routine()) @@ -170,6 +172,7 @@ async def task_routine(): else: # this branch will be tested in sync version of this test import threading + thread = threading.Thread(target=task_routine) thread.start() task_alive = thread.is_alive @@ -268,7 +271,9 @@ async def test_add_to_flow_max_mutation_limits( max_limit, ) with async_patch, sync_patch: - mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] + mutation_objs = [ + self._make_mutation(count=m[0], size=m[1]) for m in mutations + ] # flow control has no limits except API restrictions instance = self._make_one(float("inf"), float("inf")) i = 0 @@ -339,7 +344,9 @@ def _make_mutation(count=1, size=1): @pytest.mark.asyncio async def test_ctor_defaults(self): - with mock.patch.object(self._get_target_class(), "_timer_routine", return_value=asyncio.Future()) as flush_timer_mock: + with mock.patch.object( + self._get_target_class(), "_timer_routine", return_value=asyncio.Future() + ) as flush_timer_mock: table = mock.Mock() table.default_mutate_rows_operation_timeout = 10 table.default_mutate_rows_attempt_timeout = 8 @@ -363,10 +370,12 @@ async def test_ctor_defaults(self): == table.default_mutate_rows_operation_timeout ) assert ( - instance._attempt_timeout == table.default_mutate_rows_attempt_timeout + instance._attempt_timeout + == table.default_mutate_rows_attempt_timeout ) assert ( - instance._retryable_errors == table.default_mutate_rows_retryable_errors + instance._retryable_errors + == table.default_mutate_rows_retryable_errors ) await asyncio.sleep(0) assert flush_timer_mock.call_count == 1 @@ -376,7 +385,9 @@ async def test_ctor_defaults(self): @pytest.mark.asyncio async def test_ctor_explicit(self): """Test with explicit parameters""" - with mock.patch.object(self._get_target_class(), "_timer_routine", return_value=asyncio.Future()) as flush_timer_mock: + with mock.patch.object( + self._get_target_class(), "_timer_routine", return_value=asyncio.Future() + ) as flush_timer_mock: table = mock.Mock() flush_interval = 20 flush_limit_count = 17 @@ -409,7 +420,9 @@ async def test_ctor_explicit(self): instance._flow_control._max_mutation_count == flow_control_max_mutation_count ) - assert instance._flow_control._max_mutation_bytes == flow_control_max_bytes + assert ( + instance._flow_control._max_mutation_bytes == flow_control_max_bytes + ) assert instance._flow_control._in_flight_mutation_count == 0 assert instance._flow_control._in_flight_mutation_bytes == 0 assert instance._entries_processed_since_last_raise == 0 @@ -424,7 +437,9 @@ async def test_ctor_explicit(self): @pytest.mark.asyncio async def test_ctor_no_flush_limits(self): """Test with None for flush limits""" - with mock.patch.object(self._get_target_class(), "_timer_routine", return_value=asyncio.Future()) as flush_timer_mock: + with mock.patch.object( + self._get_target_class(), "_timer_routine", return_value=asyncio.Future() + ) as flush_timer_mock: table = mock.Mock() table.default_mutate_rows_operation_timeout = 10 table.default_mutate_rows_attempt_timeout = 8 @@ -495,7 +510,9 @@ def test_default_argument_consistency(self): @pytest.mark.parametrize("input_val", [None, 0, -1]) async def test__start_flush_timer_w_empty_input(self, input_val): """Empty/invalid timer should return immediately""" - with 
mock.patch.object(self._get_target_class(), "_schedule_flush") as flush_mock: + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: # mock different method depending on sync vs async async with self._make_one() as instance: if self.is_async(): @@ -510,9 +527,13 @@ async def test__start_flush_timer_w_empty_input(self, input_val): @pytest.mark.asyncio @pytest.mark.filterwarnings("ignore::RuntimeWarning") - async def test__start_flush_timer_call_when_closed(self,): + async def test__start_flush_timer_call_when_closed( + self, + ): """closed batcher's timer should return immediately""" - with mock.patch.object(self._get_target_class(), "_schedule_flush") as flush_mock: + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: async with self._make_one() as instance: await instance.close() flush_mock.reset_mock() @@ -531,7 +552,9 @@ async def test__start_flush_timer_call_when_closed(self,): @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test__flush_timer(self, num_staged): """Timer should continue to call _schedule_flush in a loop""" - with mock.patch.object(self._get_target_class(), "_schedule_flush") as flush_mock: + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: expected_sleep = 12 async with self._make_one(flush_interval=expected_sleep) as instance: loop_num = 3 @@ -544,7 +567,9 @@ async def test__flush_timer(self, num_staged): with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] with pytest.raises(TabError): - await self._get_target_class()._timer_routine(instance, expected_sleep) + await self._get_target_class()._timer_routine( + instance, expected_sleep + ) # replace with np-op so there are no issues on close instance._flush_timer = asyncio.Future() assert sleep_mock.call_count == loop_num + 1 @@ -807,7 +832,9 @@ async def test_flush_clears_job_list(self): and removed when it completes """ async with self._make_one() as instance: - with mock.patch.object(instance, "_flush_internal", AsyncMock()) as flush_mock: + with mock.patch.object( + instance, "_flush_internal", AsyncMock() + ) as flush_mock: if not self.is_async(): # simulate operation flush_mock.side_effect = lambda x: time.sleep(0.1) @@ -945,11 +972,14 @@ async def test__execute_mutate_rows_returns_errors(self): MutationsExceptionGroup, FailedMutationEntryError, ) + if self.is_async(): mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" else: mutate_path = "_sync._mutate_rows._MutateRowsOperation" - with mock.patch(f"google.cloud.bigtable.data.{mutate_path}.start") as mutate_rows: + with mock.patch( + f"google.cloud.bigtable.data.{mutate_path}.start" + ) as mutate_rows: err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) @@ -1075,7 +1105,9 @@ async def test_timeout_args_passed(self): mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" else: mutate_path = "_sync._mutate_rows._MutateRowsOperation" - with mock.patch(f"google.cloud.bigtable.data.{mutate_path}", return_value=AsyncMock()) as mutate_rows: + with mock.patch( + f"google.cloud.bigtable.data.{mutate_path}", return_value=AsyncMock() + ) as mutate_rows: expected_operation_timeout = 17 expected_attempt_timeout = 13 async with self._make_one( @@ -1182,8 +1214,14 @@ async def 
test_customizable_retryable_errors( Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. """ - retryn_fn = "retry_target_async" if "Async" in self._get_target_class().__name__ else "retry_target" - with mock.patch.object(google.api_core.retry, "if_exception_type") as predicate_builder_mock: + retryn_fn = ( + "retry_target_async" + if "Async" in self._get_target_class().__name__ + else "retry_target" + ) + with mock.patch.object( + google.api_core.retry, "if_exception_type" + ) as predicate_builder_mock: with mock.patch.object(google.api_core.retry, retryn_fn) as retry_fn_mock: table = None with mock.patch("asyncio.create_task"): diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 38c383204..600c10b3b 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -29,15 +29,16 @@ class TestReadRowsAcceptanceAsync: - @staticmethod def _get_operation_class(): from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + return _ReadRowsOperationAsync @staticmethod def _get_client_class(): from google.cloud.bigtable.data._async.client import BigtableDataClientAsync + return BigtableDataClientAsync def parse_readrows_acceptance_tests(): @@ -119,7 +120,6 @@ async def _scenerio_stream(): for expected, actual in zip_longest(test_case.results, results): assert actual == expected - @pytest.mark.parametrize( "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description ) @@ -147,6 +147,7 @@ def cancel(self): pass return mock_stream(chunk_list) + with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -155,7 +156,9 @@ def cancel(self): try: table = client.get_table("instance", "table") results = [] - with mock.patch.object(table.client._gapic_client, "read_rows") as read_rows: + with mock.patch.object( + table.client._gapic_client, "read_rows" + ) as read_rows: # run once, then return error on retry read_rows.return_value = _make_gapic_stream(test_case.chunks) async for row in await table.read_rows_stream(query={}): @@ -176,7 +179,6 @@ def cancel(self): for expected, actual in zip_longest(test_case.results, results): assert actual == expected - @pytest.mark.asyncio async def test_out_of_order_rows(self): async def _row_stream(): @@ -193,7 +195,6 @@ async def _row_stream(): async for _ in merger: pass - @pytest.mark.asyncio async def test_bare_reset(self): first_chunk = ReadRowsResponse.CellChunk( @@ -244,7 +245,6 @@ async def test_bare_reset(self): ), ) - @pytest.mark.asyncio async def test_missing_family(self): with pytest.raises(InvalidChunk): @@ -258,7 +258,6 @@ async def test_missing_family(self): ) ) - @pytest.mark.asyncio async def test_mid_cell_row_key_change(self): with pytest.raises(InvalidChunk): @@ -274,7 +273,6 @@ async def test_mid_cell_row_key_change(self): ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), ) - @pytest.mark.asyncio async def test_mid_cell_family_change(self): with pytest.raises(InvalidChunk): @@ -287,10 +285,11 @@ async def test_mid_cell_family_change(self): value_size=2, value=b"v", ), - ReadRowsResponse.CellChunk(family_name="f2", value=b"v", commit_row=True), + ReadRowsResponse.CellChunk( + family_name="f2", value=b"v", commit_row=True + ), ) - @pytest.mark.asyncio async def test_mid_cell_qualifier_change(self): 
with pytest.raises(InvalidChunk): @@ -303,10 +302,11 @@ async def test_mid_cell_qualifier_change(self): value_size=2, value=b"v", ), - ReadRowsResponse.CellChunk(qualifier=b"q2", value=b"v", commit_row=True), + ReadRowsResponse.CellChunk( + qualifier=b"q2", value=b"v", commit_row=True + ), ) - @pytest.mark.asyncio async def test_mid_cell_timestamp_change(self): with pytest.raises(InvalidChunk): @@ -324,7 +324,6 @@ async def test_mid_cell_timestamp_change(self): ), ) - @pytest.mark.asyncio async def test_mid_cell_labels_change(self): with pytest.raises(InvalidChunk): From c39d26abd8e675d5fda998ea6d57b0dd293ab52e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Apr 2024 19:54:06 -0700 Subject: [PATCH 047/360] fixed some lint issues --- google/cloud/bigtable/data/_sync/client.py | 15 +-------------- tests/unit/data/_async/test__read_rows.py | 2 -- tests/unit/data/_async/test_client.py | 2 +- tests/unit/data/_async/test_mutations_batcher.py | 2 +- tests/unit/data/_sync/test_autogen.py | 2 +- 5 files changed, 4 insertions(+), 19 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index bce9f671f..ba8cfae47 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -14,11 +14,7 @@ # from __future__ import annotations -from typing import Any, TYPE_CHECKING - -import time -import random -import threading +from typing import TYPE_CHECKING import google.auth.credentials import concurrent.futures @@ -30,17 +26,8 @@ import google.cloud.bigtable.data._sync._read_rows # noqa: F401 import google.cloud.bigtable.data._sync._mutate_rows # noqa: F401 -from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledBigtableGrpcTransport, - PooledChannel, -) -from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest - if TYPE_CHECKING: - import grpc from google.cloud.bigtable.data.row import Row - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey class BigtableDataClient(BigtableDataClient_SyncGen): diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 58d7b30f4..a3decbf19 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -13,8 +13,6 @@ import pytest -from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - # try/except added for compatibility with python < 3.8 try: from unittest import mock diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 4506ee603..5009639aa 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -2052,7 +2052,7 @@ def mock_time_generator(start_op, _): yield start_op - i with mock.patch( - f"google.cloud.bigtable.data._helpers._attempt_timeout_generator" + "google.cloud.bigtable.data._helpers._attempt_timeout_generator" ) as time_gen_mock: time_gen_mock.side_effect = mock_time_generator diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 65fc4766b..072bc7545 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -168,7 +168,7 @@ async def task_routine(): if inspect.iscoroutinefunction(task_routine): # for async class, build task to test flow unlock task = asyncio.create_task(task_routine()) 
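# --- Minimal illustrative sketch (hypothetical helper, not part of the Bigtable
# client or of this patch): the pattern these batcher tests rely on, driving the
# same wait routine as an asyncio task on the async surface and as a thread on
# the generated sync surface, and returning a zero-arg "still running?" check. ---
import asyncio
import inspect
import threading

def _start_task_or_thread(task_routine):
    """Run task_routine in the background; return a callable reporting liveness."""
    if inspect.iscoroutinefunction(task_routine):
        # async surface: schedule on the running event loop
        # (must itself be called from inside a coroutine)
        task = asyncio.create_task(task_routine())
        return lambda: not task.done()
    # sync surface: fall back to a plain thread
    thread = threading.Thread(target=task_routine, daemon=True)
    thread.start()
    return thread.is_alive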
- task_alive = lambda: not task.done() + task_alive = lambda: not task.done() # noqa else: # this branch will be tested in sync version of this test import threading diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py index 49cc2dbbe..fda0fee15 100644 --- a/tests/unit/data/_sync/test_autogen.py +++ b/tests/unit/data/_sync/test_autogen.py @@ -3473,7 +3473,7 @@ def mock_time_generator(start_op, _): yield (start_op - i) with mock.patch( - f"google.cloud.bigtable.data._helpers._attempt_timeout_generator" + "google.cloud.bigtable.data._helpers._attempt_timeout_generator" ) as time_gen_mock: time_gen_mock.side_effect = mock_time_generator with mock.patch.object(table, "read_rows", mock.Mock()) as read_rows_mock: From d5903e29465cd6af173d8f15d30750310eee591e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 18 Apr 2024 09:57:39 -0700 Subject: [PATCH 048/360] fixed mypy issues --- .../cloud/bigtable/data/_async/mutations_batcher.py | 2 +- google/cloud/bigtable/data/_sync/_autogen.py | 4 ++-- google/cloud/bigtable/data/_sync/client.py | 6 ++++-- google/cloud/bigtable/data/_sync/mutations_batcher.py | 11 ++++++----- google/cloud/bigtable/data/_sync/sync_gen.yaml | 2 +- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 6d1fd8438..7faf10f24 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -221,7 +221,7 @@ def __init__( batch_retryable_errors, table ) - self._closed: bool = asyncio.Event() + self._closed: asyncio.Event = asyncio.Event() self._table = table self._staged_entries: list[RowMutationEntry] = [] self._staged_count, self._staged_bytes = 0, 0 diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index 5486d376b..a0bb005b7 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -551,7 +551,7 @@ def __init__( self._retryable_errors: list[type[Exception]] = _get_retryable_errors( batch_retryable_errors, table ) - self._closed: bool = threading.Event() + self._closed: threading.Event = threading.Event() self._table = table self._staged_entries: list[RowMutationEntry] = [] (self._staged_count, self._staged_bytes) = (0, 0) @@ -931,7 +931,7 @@ def __init__( self._active_instances: Set[_helpers._WarmedInstanceKey] = set() self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() - self._channel_refresh_tasks: list[threading.Thread[None]] = [] + self._channel_refresh_tasks: list[concurrent.futures.Future[None]] = [] if self._emulator_host is not None: warnings.warn( "Connecting to Bigtable emulator at {}".format(self._emulator_host), diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index ba8cfae47..c591b91c0 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -54,7 +54,7 @@ def _start_background_channel_refresh(self) -> None: def _execute_ping_and_warms(self, *fns) -> list[BaseException | None]: futures_list = [self._executor.submit(f) for f in fns] - results_list = [] + results_list: list[BaseException | None] = [] for future in futures_list: try: future.result() @@ -94,7 +94,9 @@ def _shard_batch_helper( if future.exception(): results_list.append(future.exception()) else: - 
results_list.append(future.result()) + result = future.result() + if result is not None: + results_list.append(result) return results_list def __enter__(self): diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index f1ea43af8..5a60e0831 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -66,14 +66,15 @@ def _wait_for_batch_results( ) -> list[Exception]: if not tasks: return [] - exceptions = [] + exceptions: list[Exception] = [] for task in tasks: try: exc_list = task.result() - for exc in exc_list: - # strip index information - exc.index = None - exceptions.extend(exc_list) + if exc_list: + for exc in exc_list: + # strip index information + exc.index = None + exceptions.extend(exc_list) except Exception as e: exceptions.append(e) return exceptions diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml index 850369331..aa5282e69 100644 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ b/google/cloud/bigtable/data/_sync/sync_gen.yaml @@ -3,7 +3,7 @@ asyncio_replacements: # Replace asyncio functionaility Queue: queue.Queue Condition: threading.Condition Future: concurrent.futures.Future - Task: threading.Thread + Task: concurrent.futures.Future Event: threading.Event text_replacements: # Find and replace specific text patterns From f9862a2db721dea16e952b1976c99359645aa4b9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 18 Apr 2024 10:48:57 -0700 Subject: [PATCH 049/360] refactored system tests into class --- tests/system/data/setup_fixtures.py | 23 - tests/system/data/test_system.py | 1714 ++++++++++++++------------- 2 files changed, 872 insertions(+), 865 deletions(-) diff --git a/tests/system/data/setup_fixtures.py b/tests/system/data/setup_fixtures.py index 77086b7f3..11013938b 100644 --- a/tests/system/data/setup_fixtures.py +++ b/tests/system/data/setup_fixtures.py @@ -23,14 +23,6 @@ import uuid -@pytest.fixture(scope="session") -def event_loop(): - loop = asyncio.get_event_loop() - yield loop - loop.stop() - loop.close() - - @pytest.fixture(scope="session") def admin_client(): """ @@ -150,22 +142,7 @@ def table_id( print(f"Table {init_table_id} not found, skipping deletion") -@pytest_asyncio.fixture(scope="session") -async def client(): - from google.cloud.bigtable.data import BigtableDataClientAsync - - project = os.getenv("GOOGLE_CLOUD_PROJECT") or None - async with BigtableDataClientAsync(project=project, pool_size=4) as client: - yield client - - @pytest.fixture(scope="session") def project_id(client): """Returns the project ID from the client.""" yield client.project - - -@pytest_asyncio.fixture(scope="session") -async def table(client, table_id, instance_id): - async with client.get_table(instance_id, table_id) as table: - yield table diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py index aeb08fc1a..3a229fc28 100644 --- a/tests/system/data/test_system.py +++ b/tests/system/data/test_system.py @@ -27,40 +27,6 @@ TEST_FAMILY_2 = "test-family-2" -@pytest.fixture(scope="session") -def column_family_config(): - """ - specify column families to create when creating a new test table - """ - from google.cloud.bigtable_admin_v2 import types - - return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} - - -@pytest.fixture(scope="session") -def init_table_id(): - """ - The table_id to use when creating a new test table 
- """ - return f"test-table-{uuid.uuid4().hex}" - - -@pytest.fixture(scope="session") -def cluster_config(project_id): - """ - Configuration for the clusters to use when creating a new instance - """ - from google.cloud.bigtable_admin_v2 import types - - cluster = { - "test-cluster": types.Cluster( - location=f"projects/{project_id}/locations/us-central1-b", - serve_nodes=1, - ) - } - return cluster - - class TempRowBuilder: """ Used to add rows to a table for testing purposes. @@ -105,839 +71,903 @@ async def delete_rows(self): await self.table.client._gapic_client.mutate_rows(request) -@pytest.mark.usefixtures("table") -async def _retrieve_cell_value(table, row_key): - """ - Helper to read an individual row - """ - from google.cloud.bigtable.data import ReadRowsQuery - - row_list = await table.read_rows(ReadRowsQuery(row_keys=row_key)) - assert len(row_list) == 1 - row = row_list[0] - cell = row.cells[0] - return cell.value - - -async def _create_row_and_mutation( - table, temp_rows, *, start_value=b"start", new_value=b"new_value" -): - """ - Helper to create a new row, and a sample set_cell mutation to change its value - """ - from google.cloud.bigtable.data.mutations import SetCell - - row_key = uuid.uuid4().hex.encode() - family = TEST_FAMILY - qualifier = b"test-qualifier" - await temp_rows.add_row( - row_key, family=family, qualifier=qualifier, value=start_value +class TestSystemAsync: + @pytest_asyncio.fixture(scope="session") + async def client(self): + from google.cloud.bigtable.data import BigtableDataClientAsync + + project = os.getenv("GOOGLE_CLOUD_PROJECT") or None + async with BigtableDataClientAsync(project=project, pool_size=4) as client: + yield client + + @pytest_asyncio.fixture(scope="session") + async def table(self, client, table_id, instance_id): + async with client.get_table( + instance_id, + table_id, + ) as table: + yield table + + @pytest.fixture(scope="session") + def event_loop(self): + loop = asyncio.get_event_loop() + yield loop + loop.stop() + loop.close() + + @pytest.fixture(scope="session") + def column_family_config(self): + """ + specify column families to create when creating a new test table + """ + from google.cloud.bigtable_admin_v2 import types + + return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} + + @pytest.fixture(scope="session") + def init_table_id(self): + """ + The table_id to use when creating a new test table + """ + return f"test-table-{uuid.uuid4().hex}" + + @pytest.fixture(scope="session") + def cluster_config(self, project_id): + """ + Configuration for the clusters to use when creating a new instance + """ + from google.cloud.bigtable_admin_v2 import types + + cluster = { + "test-cluster": types.Cluster( + location=f"projects/{project_id}/locations/us-central1-b", + serve_nodes=1, + ) + } + return cluster + + @pytest.mark.usefixtures("table") + async def _retrieve_cell_value(self, table, row_key): + """ + Helper to read an individual row + """ + from google.cloud.bigtable.data import ReadRowsQuery + + row_list = await table.read_rows(ReadRowsQuery(row_keys=row_key)) + assert len(row_list) == 1 + row = row_list[0] + cell = row.cells[0] + return cell.value + + async def _create_row_and_mutation( + self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" + ): + """ + Helper to create a new row, and a sample set_cell mutation to change its value + """ + from google.cloud.bigtable.data.mutations import SetCell + + row_key = uuid.uuid4().hex.encode() + family = TEST_FAMILY + qualifier = 
b"test-qualifier" + await temp_rows.add_row( + row_key, family=family, qualifier=qualifier, value=start_value + ) + # ensure cell is initialized + assert (await self._retrieve_cell_value(table, row_key)) == start_value + + mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) + return row_key, mutation + + @pytest_asyncio.fixture(scope="function") + async def temp_rows(self, table): + builder = TempRowBuilder(table) + yield builder + await builder.delete_rows() + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=10 ) - # ensure cell is initialized - assert (await _retrieve_cell_value(table, row_key)) == start_value - - mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) - return row_key, mutation - - -@pytest.mark.usefixtures("table") -@pytest_asyncio.fixture(scope="function") -async def temp_rows(table): - builder = TempRowBuilder(table) - yield builder - await builder.delete_rows() - - -@pytest.mark.usefixtures("table") -@pytest.mark.usefixtures("client") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=10) -@pytest.mark.asyncio -async def test_ping_and_warm_gapic(client, table): - """ - Simple ping rpc test - This test ensures channels are able to authenticate with backend - """ - request = {"name": table.instance_name} - await client._gapic_client.ping_and_warm(request) - - -@pytest.mark.usefixtures("table") -@pytest.mark.usefixtures("client") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_ping_and_warm(client, table): - """ - Test ping and warm from handwritten client - """ - try: - channel = client.transport._grpc_channel.pool[0] - except Exception: - # for sync client - channel = client.transport._grpc_channel - results = await client._ping_and_warm_instances(channel) - assert len(results) == 1 - assert results[0] is None - - -@pytest.mark.asyncio -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -async def test_mutation_set_cell(table, temp_rows): - """ - Ensure cells can be set properly - """ - row_key = b"bulk_mutate" - new_value = uuid.uuid4().hex.encode() - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value + @pytest.mark.asyncio + async def test_ping_and_warm_gapic(self, client, table): + """ + Simple ping rpc test + This test ensures channels are able to authenticate with backend + """ + request = {"name": table.instance_name} + await client._gapic_client.ping_and_warm(request) + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - await table.mutate_row(row_key, mutation) - - # ensure cell is updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - - -@pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" -) -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_sample_row_keys(client, table, temp_rows, column_split_config): - """ - Sample keys should return a single sample in small test tables - """ - await temp_rows.add_row(b"row_key_1") - await 
temp_rows.add_row(b"row_key_2") - - results = await table.sample_row_keys() - assert len(results) == len(column_split_config) + 1 - # first keys should match the split config - for idx in range(len(column_split_config)): - assert results[idx][0] == column_split_config[idx] - assert isinstance(results[idx][1], int) - # last sample should be empty key - assert results[-1][0] == b"" - assert isinstance(results[-1][1], int) - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_bulk_mutations_set_cell(client, table, temp_rows): - """ - Ensure cells can be set properly - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value + @pytest.mark.asyncio + async def test_ping_and_warm(self, client, table): + """ + Test ping and warm from handwritten client + """ + try: + channel = client.transport._grpc_channel.pool[0] + except Exception: + # for sync client + channel = client.transport._grpc_channel + results = await client._ping_and_warm_instances(channel) + assert len(results) == 1 + assert results[0] is None + + @pytest.mark.asyncio + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - - await table.bulk_mutate_rows([bulk_mutation]) - - # ensure cell is updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - + async def test_mutation_set_cell(self, table, temp_rows): + """ + Ensure cells can be set properly + """ + row_key = b"bulk_mutate" + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + await table.mutate_row(row_key, mutation) -@pytest.mark.asyncio -async def test_bulk_mutations_raise_exception(client, table): - """ - If an invalid mutation is passed, an exception should be raised - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell - from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup - from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + # ensure cell is updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value - row_key = uuid.uuid4().hex.encode() - mutation = SetCell(family="nonexistent", qualifier=b"test-qualifier", new_value=b"") - bulk_mutation = RowMutationEntry(row_key, [mutation]) + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" + ) + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @pytest.mark.asyncio + async def test_sample_row_keys(self, client, table, temp_rows, column_split_config): + """ + Sample keys should return a single sample in small test tables + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + + results = await table.sample_row_keys() + assert len(results) == len(column_split_config) + 1 + # first keys should match the split config + for idx in range(len(column_split_config)): + assert results[idx][0] == column_split_config[idx] + assert isinstance(results[idx][1], int) + # last sample should be empty key + assert results[-1][0] == b"" + assert isinstance(results[-1][1], int) + + @pytest.mark.usefixtures("client") 
+ @pytest.mark.usefixtures("table") + @pytest.mark.asyncio + async def test_bulk_mutations_set_cell(self, client, table, temp_rows): + """ + Ensure cells can be set properly + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) - with pytest.raises(MutationsExceptionGroup) as exc: await table.bulk_mutate_rows([bulk_mutation]) - assert len(exc.value.exceptions) == 1 - entry_error = exc.value.exceptions[0] - assert isinstance(entry_error, FailedMutationEntryError) - assert entry_error.index == 0 - assert entry_error.entry == bulk_mutation - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_mutations_batcher_context_manager(client, table, temp_rows): - """ - test batcher with context manager. Should flush on exit - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - row_key2, mutation2 = await _create_row_and_mutation( - table, temp_rows, new_value=new_value2 + # ensure cell is updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + + @pytest.mark.asyncio + async def test_bulk_mutations_raise_exception(self, client, table): + """ + If an invalid mutation is passed, an exception should be raised + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + row_key = uuid.uuid4().hex.encode() + mutation = SetCell( + family="nonexistent", qualifier=b"test-qualifier", new_value=b"" + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + + with pytest.raises(MutationsExceptionGroup) as exc: + await table.bulk_mutate_rows([bulk_mutation]) + assert len(exc.value.exceptions) == 1 + entry_error = exc.value.exceptions[0] + assert isinstance(entry_error, FailedMutationEntryError) + assert entry_error.index == 0 + assert entry_error.entry == bulk_mutation + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - async with table.mutations_batcher() as batcher: - await batcher.append(bulk_mutation) - await batcher.append(bulk_mutation2) - # ensure cell is updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - assert len(batcher._staged_entries) == 0 - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_mutations_batcher_timer_flush(client, table, temp_rows): - """ - batch should occur after flush_interval seconds - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry + @pytest.mark.asyncio + async def test_mutations_batcher_context_manager(self, client, table, temp_rows): + """ + test batcher with context manager. 
Should flush on exit + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + row_key2, mutation2 = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - new_value = uuid.uuid4().hex.encode() - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - flush_interval = 0.1 - async with table.mutations_batcher(flush_interval=flush_interval) as batcher: - await batcher.append(bulk_mutation) - await asyncio.sleep(0) - assert len(batcher._staged_entries) == 1 - await asyncio.sleep(flush_interval + 0.1) - assert len(batcher._staged_entries) == 0 + async with table.mutations_batcher() as batcher: + await batcher.append(bulk_mutation) + await batcher.append(bulk_mutation2) # ensure cell is updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_mutations_batcher_count_flush(client, table, temp_rows): - """ - batch should flush after flush_limit_mutation_count mutations - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry + assert (await self._retrieve_cell_value(table, row_key)) == new_value + assert len(batcher._staged_entries) == 0 - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = await _create_row_and_mutation( - table, temp_rows, new_value=new_value2 + @pytest.mark.asyncio + async def test_mutations_batcher_timer_flush(self, client, table, temp_rows): + """ + batch should occur after flush_interval seconds + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + flush_interval = 0.1 + async with table.mutations_batcher(flush_interval=flush_interval) as batcher: + await batcher.append(bulk_mutation) + await asyncio.sleep(0) + assert len(batcher._staged_entries) == 1 + await asyncio.sleep(flush_interval + 0.1) + assert len(batcher._staged_entries) == 0 + # ensure cell is updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - async with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: - await batcher.append(bulk_mutation) - assert len(batcher._flush_jobs) == 0 - # should be noop; flush not scheduled - assert len(batcher._staged_entries) == 1 - await batcher.append(bulk_mutation2) - # task 
should now be scheduled - assert len(batcher._flush_jobs) == 1 - await asyncio.gather(*batcher._flush_jobs) - assert len(batcher._staged_entries) == 0 - assert len(batcher._flush_jobs) == 0 - # ensure cells were updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - assert (await _retrieve_cell_value(table, row_key2)) == new_value2 - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_mutations_batcher_bytes_flush(client, table, temp_rows): - """ - batch should flush after flush_limit_bytes bytes - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value + @pytest.mark.asyncio + async def test_mutations_batcher_count_flush(self, client, table, temp_rows): + """ + batch should flush after flush_limit_mutation_count mutations + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + async with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + # should be noop; flush not scheduled + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # task should now be scheduled + assert len(batcher._flush_jobs) == 1 + await asyncio.gather(*batcher._flush_jobs) + assert len(batcher._staged_entries) == 0 + assert len(batcher._flush_jobs) == 0 + # ensure cells were updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + assert (await self._retrieve_cell_value(table, row_key2)) == new_value2 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = await _create_row_and_mutation( - table, temp_rows, new_value=new_value2 + @pytest.mark.asyncio + async def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): + """ + batch should flush after flush_limit_bytes bytes + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 + + async with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # task should now be scheduled + assert len(batcher._flush_jobs) == 1 + 
assert len(batcher._staged_entries) == 0 + # let flush complete + await asyncio.gather(*batcher._flush_jobs) + # ensure cells were updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + assert (await self._retrieve_cell_value(table, row_key2)) == new_value2 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.asyncio + async def test_mutations_batcher_no_flush(self, client, table, temp_rows): + """ + test with no flush requirements met + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + start_value = b"unchanged" + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 + async with table.mutations_batcher( + flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 + ) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # flush not scheduled + assert len(batcher._flush_jobs) == 0 + await asyncio.sleep(0.01) + assert len(batcher._staged_entries) == 2 + assert len(batcher._flush_jobs) == 0 + # ensure cells were not updated + assert (await self._retrieve_cell_value(table, row_key)) == start_value + assert (await self._retrieve_cell_value(table, row_key2)) == start_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,increment,expected", + [ + (0, 0, 0), + (0, 1, 1), + (0, -1, -1), + (1, 0, 1), + (0, -100, -100), + (0, 3000, 3000), + (10, 4, 14), + (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), + (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), + (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), + ], ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 + @pytest.mark.asyncio + async def test_read_modify_write_row_increment( + self, client, table, temp_rows, start, increment, expected + ): + """ + test read_modify_write_row + """ + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row( + row_key, value=start, family=family, qualifier=qualifier + ) - async with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: - await batcher.append(bulk_mutation) - assert len(batcher._flush_jobs) == 0 - assert len(batcher._staged_entries) == 1 - await batcher.append(bulk_mutation2) - # task should now be scheduled - assert len(batcher._flush_jobs) == 1 - assert len(batcher._staged_entries) == 0 - # let flush complete - await asyncio.gather(*batcher._flush_jobs) - # ensure cells were updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - assert (await _retrieve_cell_value(table, row_key2)) == new_value2 + rule = IncrementRule(family, qualifier, increment) + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert int(result[0]) == expected + # ensure 
that reading from server gives same value + assert (await self._retrieve_cell_value(table, row_key)) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,append,expected", + [ + (b"", b"", b""), + ("", "", b""), + (b"abc", b"123", b"abc123"), + (b"abc", "123", b"abc123"), + ("", b"1", b"1"), + (b"abc", "", b"abc"), + (b"hello", b"world", b"helloworld"), + ], + ) + @pytest.mark.asyncio + async def test_read_modify_write_row_append( + self, client, table, temp_rows, start, append, expected + ): + """ + test read_modify_write_row + """ + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row( + row_key, value=start, family=family, qualifier=qualifier + ) + rule = AppendValueRule(family, qualifier, append) + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert result[0].value == expected + # ensure that reading from server gives same value + assert (await self._retrieve_cell_value(table, row_key)) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.asyncio + async def test_read_modify_write_row_chained(self, client, table, temp_rows): + """ + test read_modify_write_row with multiple rules + """ + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + start_amount = 1 + increment_amount = 10 + await temp_rows.add_row( + row_key, value=start_amount, family=family, qualifier=qualifier + ) + rule = [ + IncrementRule(family, qualifier, increment_amount), + AppendValueRule(family, qualifier, "hello"), + AppendValueRule(family, qualifier, "world"), + AppendValueRule(family, qualifier, "!"), + ] + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert result[0].family == family + assert result[0].qualifier == qualifier + # result should be a bytes number string for the IncrementRules, followed by the AppendValueRule values + assert ( + result[0].value + == (start_amount + increment_amount).to_bytes(8, "big", signed=True) + + b"helloworld!" 
+ ) + # ensure that reading from server gives same value + assert (await self._retrieve_cell_value(table, row_key)) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start_val,predicate_range,expected_result", + [ + (1, (0, 2), True), + (-1, (0, 2), False), + ], + ) + @pytest.mark.asyncio + async def test_check_and_mutate( + self, client, table, temp_rows, start_val, predicate_range, expected_result + ): + """ + test that check_and_mutate_row works applies the right mutations, and returns the right result + """ + from google.cloud.bigtable.data.mutations import SetCell + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + + await temp_rows.add_row( + row_key, value=start_val, family=family, qualifier=qualifier + ) -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_mutations_batcher_no_flush(client, table, temp_rows): - """ - test with no flush requirements met - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry + false_mutation_value = b"false-mutation-value" + false_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value + ) + true_mutation_value = b"true-mutation-value" + true_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value + ) + predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) + result = await table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + assert result == expected_result + # ensure cell is updated + expected_value = ( + true_mutation_value if expected_result else false_mutation_value + ) + assert (await self._retrieve_cell_value(table, row_key)) == expected_value - new_value = uuid.uuid4().hex.encode() - start_value = b"unchanged" - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = await _create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.asyncio + async def test_check_and_mutate_empty_request(self, client, table): + """ + check_and_mutate with no true or fale mutations should raise an error + """ + from google.api_core import exceptions + + with pytest.raises(exceptions.InvalidArgument) as e: + await table.check_and_mutate_row( + b"row_key", None, true_case_mutations=None, false_case_mutations=None + ) + assert "No mutations provided" in str(e.value) + + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 - async with table.mutations_batcher( - flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 - ) as batcher: - await batcher.append(bulk_mutation) - assert len(batcher._staged_entries) == 1 - await batcher.append(bulk_mutation2) - # flush not scheduled - assert len(batcher._flush_jobs) == 0 - await asyncio.sleep(0.01) - assert 
len(batcher._staged_entries) == 2 - assert len(batcher._flush_jobs) == 0 - # ensure cells were not updated - assert (await _retrieve_cell_value(table, row_key)) == start_value - assert (await _retrieve_cell_value(table, row_key2)) == start_value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.parametrize( - "start,increment,expected", - [ - (0, 0, 0), - (0, 1, 1), - (0, -1, -1), - (1, 0, 1), - (0, -100, -100), - (0, 3000, 3000), - (10, 4, 14), - (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), - (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), - (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), - ], -) -@pytest.mark.asyncio -async def test_read_modify_write_row_increment( - client, table, temp_rows, start, increment, expected -): - """ - test read_modify_write_row - """ - from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) - - rule = IncrementRule(family, qualifier, increment) - result = await table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert len(result) == 1 - assert result[0].family == family - assert result[0].qualifier == qualifier - assert int(result[0]) == expected - # ensure that reading from server gives same value - assert (await _retrieve_cell_value(table, row_key)) == result[0].value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.parametrize( - "start,append,expected", - [ - (b"", b"", b""), - ("", "", b""), - (b"abc", b"123", b"abc123"), - (b"abc", "123", b"abc123"), - ("", b"1", b"1"), - (b"abc", "", b"abc"), - (b"hello", b"world", b"helloworld"), - ], -) -@pytest.mark.asyncio -async def test_read_modify_write_row_append( - client, table, temp_rows, start, append, expected -): - """ - test read_modify_write_row - """ - from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) - - rule = AppendValueRule(family, qualifier, append) - result = await table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert len(result) == 1 - assert result[0].family == family - assert result[0].qualifier == qualifier - assert result[0].value == expected - # ensure that reading from server gives same value - assert (await _retrieve_cell_value(table, row_key)) == result[0].value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_modify_write_row_chained(client, table, temp_rows): - """ - test read_modify_write_row with multiple rules - """ - from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule - from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - start_amount = 1 - increment_amount = 10 - await temp_rows.add_row( - row_key, value=start_amount, family=family, qualifier=qualifier + @pytest.mark.asyncio + async def test_read_rows_stream(self, table, temp_rows): + """ + Ensure that the read_rows_stream method works + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + + # full table scan + generator = await table.read_rows_stream({}) + 
first_row = await generator.__anext__() + second_row = await generator.__anext__() + assert first_row.row_key == b"row_key_1" + assert second_row.row_key == b"row_key_2" + with pytest.raises(StopAsyncIteration): + await generator.__anext__() + + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - rule = [ - IncrementRule(family, qualifier, increment_amount), - AppendValueRule(family, qualifier, "hello"), - AppendValueRule(family, qualifier, "world"), - AppendValueRule(family, qualifier, "!"), - ] - result = await table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert result[0].family == family - assert result[0].qualifier == qualifier - # result should be a bytes number string for the IncrementRules, followed by the AppendValueRule values - assert ( - result[0].value - == (start_amount + increment_amount).to_bytes(8, "big", signed=True) - + b"helloworld!" + @pytest.mark.asyncio + async def test_read_rows(self, table, temp_rows): + """ + Ensure that the read_rows method works + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + # full table scan + row_list = await table.read_rows({}) + assert len(row_list) == 2 + assert row_list[0].row_key == b"row_key_1" + assert row_list[1].row_key == b"row_key_2" + + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - # ensure that reading from server gives same value - assert (await _retrieve_cell_value(table, row_key)) == result[0].value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.parametrize( - "start_val,predicate_range,expected_result", - [ - (1, (0, 2), True), - (-1, (0, 2), False), - ], -) -@pytest.mark.asyncio -async def test_check_and_mutate( - client, table, temp_rows, start_val, predicate_range, expected_result -): - """ - test that check_and_mutate_row works applies the right mutations, and returns the right result - """ - from google.cloud.bigtable.data.mutations import SetCell - from google.cloud.bigtable.data.row_filters import ValueRangeFilter - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - - await temp_rows.add_row( - row_key, value=start_val, family=family, qualifier=qualifier + @pytest.mark.asyncio + async def test_read_rows_sharded_simple(self, table, temp_rows): + """ + Test read rows sharded with two queries + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) + row_list = await table.read_rows_sharded([query1, query2]) + assert len(row_list) == 4 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"b" + assert row_list[3].row_key == b"d" + + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - - false_mutation_value = b"false-mutation-value" - false_mutation = SetCell( - family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value + @pytest.mark.asyncio + async def test_read_rows_sharded_from_sample(self, table, temp_rows): + """ + Test end-to-end sharding + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from 
google.cloud.bigtable.data.read_rows_query import RowRange + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + + table_shard_keys = await table.sample_row_keys() + query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) + shard_queries = query.shard(table_shard_keys) + row_list = await table.read_rows_sharded(shard_queries) + assert len(row_list) == 3 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"d" + + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - true_mutation_value = b"true-mutation-value" - true_mutation = SetCell( - family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value + @pytest.mark.asyncio + async def test_read_rows_sharded_filters_limits(self, table, temp_rows): + """ + Test read rows sharded with filters and limits + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + + label_filter1 = ApplyLabelFilter("first") + label_filter2 = ApplyLabelFilter("second") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) + row_list = await table.read_rows_sharded([query1, query2]) + assert len(row_list) == 3 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"b" + assert row_list[2].row_key == b"d" + assert row_list[0][0].labels == ["first"] + assert row_list[1][0].labels == ["second"] + assert row_list[2][0].labels == ["second"] + + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) - result = await table.check_and_mutate_row( - row_key, - predicate, - true_case_mutations=true_mutation, - false_case_mutations=false_mutation, + @pytest.mark.asyncio + async def test_read_rows_range_query(self, table, temp_rows): + """ + Ensure that the read_rows method works + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data import RowRange + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # full table scan + query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) + row_list = await table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - assert result == expected_result - # ensure cell is updated - expected_value = true_mutation_value if expected_result else false_mutation_value - assert (await _retrieve_cell_value(table, row_key)) == expected_value - - -@pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", -) -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_check_and_mutate_empty_request(client, table): - """ - check_and_mutate with no true or fale mutations should raise an error - """ - from google.api_core 
import exceptions - - with pytest.raises(exceptions.InvalidArgument) as e: - await table.check_and_mutate_row( - b"row_key", None, true_case_mutations=None, false_case_mutations=None - ) - assert "No mutations provided" in str(e.value) - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_stream(table, temp_rows): - """ - Ensure that the read_rows_stream method works - """ - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - - # full table scan - generator = await table.read_rows_stream({}) - first_row = await generator.__anext__() - second_row = await generator.__anext__() - assert first_row.row_key == b"row_key_1" - assert second_row.row_key == b"row_key_2" - with pytest.raises(StopAsyncIteration): - await generator.__anext__() - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows(table, temp_rows): - """ - Ensure that the read_rows method works - """ - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - # full table scan - row_list = await table.read_rows({}) - assert len(row_list) == 2 - assert row_list[0].row_key == b"row_key_1" - assert row_list[1].row_key == b"row_key_2" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_sharded_simple(table, temp_rows): - """ - Test read rows sharded with two queries - """ - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) - query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) - row_list = await table.read_rows_sharded([query1, query2]) - assert len(row_list) == 4 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"c" - assert row_list[2].row_key == b"b" - assert row_list[3].row_key == b"d" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_sharded_from_sample(table, temp_rows): - """ - Test end-to-end sharding - """ - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - from google.cloud.bigtable.data.read_rows_query import RowRange - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - - table_shard_keys = await table.sample_row_keys() - query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) - shard_queries = query.shard(table_shard_keys) - row_list = await table.read_rows_sharded(shard_queries) - assert len(row_list) == 3 - assert row_list[0].row_key == b"b" - assert row_list[1].row_key == b"c" - assert row_list[2].row_key == b"d" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_sharded_filters_limits(table, temp_rows): - """ - Test read rows sharded with filters and limits - """ - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - await temp_rows.add_row(b"a") - await 
temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - - label_filter1 = ApplyLabelFilter("first") - label_filter2 = ApplyLabelFilter("second") - query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) - query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) - row_list = await table.read_rows_sharded([query1, query2]) - assert len(row_list) == 3 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"b" - assert row_list[2].row_key == b"d" - assert row_list[0][0].labels == ["first"] - assert row_list[1][0].labels == ["second"] - assert row_list[2][0].labels == ["second"] - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_range_query(table, temp_rows): - """ - Ensure that the read_rows method works - """ - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable.data import RowRange - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - # full table scan - query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) - row_list = await table.read_rows(query) - assert len(row_list) == 2 - assert row_list[0].row_key == b"b" - assert row_list[1].row_key == b"c" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_single_key_query(table, temp_rows): - """ - Ensure that the read_rows method works with specified query - """ - from google.cloud.bigtable.data import ReadRowsQuery - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - # retrieve specific keys - query = ReadRowsQuery(row_keys=[b"a", b"c"]) - row_list = await table.read_rows(query) - assert len(row_list) == 2 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"c" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_with_filter(table, temp_rows): - """ - ensure filters are applied - """ - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - # retrieve keys with filter - expected_label = "test-label" - row_filter = ApplyLabelFilter(expected_label) - query = ReadRowsQuery(row_filter=row_filter) - row_list = await table.read_rows(query) - assert len(row_list) == 4 - for row in row_list: - assert row[0].labels == [expected_label] - - -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_rows_stream_close(table, temp_rows): - """ - Ensure that the read_rows_stream can be closed - """ - from google.cloud.bigtable.data import ReadRowsQuery - - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - # full table scan - query = ReadRowsQuery() - generator = await table.read_rows_stream(query) - # grab first row - first_row = await generator.__anext__() - assert first_row.row_key == b"row_key_1" - # close stream early - await generator.aclose() - with pytest.raises(StopAsyncIteration): - await generator.__anext__() - - 
-@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_row(table, temp_rows): - """ - Test read_row (single row helper) - """ - from google.cloud.bigtable.data import Row - - await temp_rows.add_row(b"row_key_1", value=b"value") - row = await table.read_row(b"row_key_1") - assert isinstance(row, Row) - assert row.row_key == b"row_key_1" - assert row.cells[0].value == b"value" - - -@pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", -) -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_row_missing(table): - """ - Test read_row when row does not exist - """ - from google.api_core import exceptions - - row_key = "row_key_not_exist" - result = await table.read_row(row_key) - assert result is None - with pytest.raises(exceptions.InvalidArgument) as e: - await table.read_row("") - assert "Row keys must be non-empty" in str(e) - - -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_row_w_filter(table, temp_rows): - """ - Test read_row (single row helper) - """ - from google.cloud.bigtable.data import Row - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - await temp_rows.add_row(b"row_key_1", value=b"value") - expected_label = "test-label" - label_filter = ApplyLabelFilter(expected_label) - row = await table.read_row(b"row_key_1", row_filter=label_filter) - assert isinstance(row, Row) - assert row.row_key == b"row_key_1" - assert row.cells[0].value == b"value" - assert row.cells[0].labels == [expected_label] - - -@pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", -) -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_row_exists(table, temp_rows): - from google.api_core import exceptions - - """Test row_exists with rows that exist and don't exist""" - assert await table.row_exists(b"row_key_1") is False - await temp_rows.add_row(b"row_key_1") - assert await table.row_exists(b"row_key_1") is True - assert await table.row_exists("row_key_1") is True - assert await table.row_exists(b"row_key_2") is False - assert await table.row_exists("row_key_2") is False - assert await table.row_exists("3") is False - await temp_rows.add_row(b"3") - assert await table.row_exists(b"3") is True - with pytest.raises(exceptions.InvalidArgument) as e: - await table.row_exists("") - assert "Row keys must be non-empty" in str(e) - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.parametrize( - "cell_value,filter_input,expect_match", - [ - (b"abc", b"abc", True), - (b"abc", "abc", True), - (b".", ".", True), - (".*", ".*", True), - (".*", b".*", True), - ("a", ".*", False), - (b".*", b".*", True), - (r"\a", r"\a", True), - (b"\xe2\x98\x83", "☃", True), - ("☃", "☃", True), - (r"\C☃", r"\C☃", True), - (1, 1, True), - (2, 1, False), - (68, 68, True), - ("D", 68, False), - (68, "D", False), - (-1, -1, True), - (2852126720, 2852126720, True), - (-1431655766, -1431655766, True), - (-1431655766, -1, False), - ], -) -@pytest.mark.asyncio -async def test_literal_value_filter( - table, temp_rows, cell_value, filter_input, expect_match -): - """ - Literal value filter does complex escaping on re2 strings. 
- Make sure inputs are properly interpreted by the server - """ - from google.cloud.bigtable.data.row_filters import LiteralValueFilter - from google.cloud.bigtable.data import ReadRowsQuery - - f = LiteralValueFilter(filter_input) - await temp_rows.add_row(b"row_key_1", value=cell_value) - query = ReadRowsQuery(row_filter=f) - row_list = await table.read_rows(query) - assert len(row_list) == bool( - expect_match - ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" + @pytest.mark.asyncio + async def test_read_rows_single_key_query(self, table, temp_rows): + """ + Ensure that the read_rows method works with specified query + """ + from google.cloud.bigtable.data import ReadRowsQuery + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # retrieve specific keys + query = ReadRowsQuery(row_keys=[b"a", b"c"]) + row_list = await table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @pytest.mark.asyncio + async def test_read_rows_with_filter(self, table, temp_rows): + """ + ensure filters are applied + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # retrieve keys with filter + expected_label = "test-label" + row_filter = ApplyLabelFilter(expected_label) + query = ReadRowsQuery(row_filter=row_filter) + row_list = await table.read_rows(query) + assert len(row_list) == 4 + for row in row_list: + assert row[0].labels == [expected_label] + + @pytest.mark.usefixtures("table") + @pytest.mark.asyncio + async def test_read_rows_stream_close(self, table, temp_rows): + """ + Ensure that the read_rows_stream can be closed + """ + from google.cloud.bigtable.data import ReadRowsQuery + + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + # full table scan + query = ReadRowsQuery() + generator = await table.read_rows_stream(query) + # grab first row + first_row = await generator.__anext__() + assert first_row.row_key == b"row_key_1" + # close stream early + await generator.aclose() + with pytest.raises(StopAsyncIteration): + await generator.__anext__() + + @pytest.mark.usefixtures("table") + @pytest.mark.asyncio + async def test_read_row(self, table, temp_rows): + """ + Test read_row (single row helper) + """ + from google.cloud.bigtable.data import Row + + await temp_rows.add_row(b"row_key_1", value=b"value") + row = await table.read_row(b"row_key_1") + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + @pytest.mark.asyncio + async def test_read_row_missing(self, table): + """ + Test read_row when row does not exist + """ + from google.api_core import exceptions + + row_key = "row_key_not_exist" + result = await table.read_row(row_key) + assert result is None + with pytest.raises(exceptions.InvalidArgument) as e: + await table.read_row("") + assert "Row keys must be non-empty" in str(e) + + @pytest.mark.usefixtures("table") + 
@pytest.mark.asyncio + async def test_read_row_w_filter(self, table, temp_rows): + """ + Test read_row (single row helper) + """ + from google.cloud.bigtable.data import Row + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"row_key_1", value=b"value") + expected_label = "test-label" + label_filter = ApplyLabelFilter(expected_label) + row = await table.read_row(b"row_key_1", row_filter=label_filter) + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + assert row.cells[0].labels == [expected_label] + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + @pytest.mark.asyncio + async def test_row_exists(self, table, temp_rows): + from google.api_core import exceptions + + """Test row_exists with rows that exist and don't exist""" + assert await table.row_exists(b"row_key_1") is False + await temp_rows.add_row(b"row_key_1") + assert await table.row_exists(b"row_key_1") is True + assert await table.row_exists("row_key_1") is True + assert await table.row_exists(b"row_key_2") is False + assert await table.row_exists("row_key_2") is False + assert await table.row_exists("3") is False + await temp_rows.add_row(b"3") + assert await table.row_exists(b"3") is True + with pytest.raises(exceptions.InvalidArgument) as e: + await table.row_exists("") + assert "Row keys must be non-empty" in str(e) + + @pytest.mark.usefixtures("table") + @retry.AsyncRetry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @pytest.mark.parametrize( + "cell_value,filter_input,expect_match", + [ + (b"abc", b"abc", True), + (b"abc", "abc", True), + (b".", ".", True), + (".*", ".*", True), + (".*", b".*", True), + ("a", ".*", False), + (b".*", b".*", True), + (r"\a", r"\a", True), + (b"\xe2\x98\x83", "☃", True), + ("☃", "☃", True), + (r"\C☃", r"\C☃", True), + (1, 1, True), + (2, 1, False), + (68, 68, True), + ("D", 68, False), + (68, "D", False), + (-1, -1, True), + (2852126720, 2852126720, True), + (-1431655766, -1431655766, True), + (-1431655766, -1, False), + ], + ) + @pytest.mark.asyncio + async def test_literal_value_filter( + self, table, temp_rows, cell_value, filter_input, expect_match + ): + """ + Literal value filter does complex escaping on re2 strings. 
+ Make sure inputs are properly interpreted by the server + """ + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + from google.cloud.bigtable.data import ReadRowsQuery + + f = LiteralValueFilter(filter_input) + await temp_rows.add_row(b"row_key_1", value=cell_value) + query = ReadRowsQuery(row_filter=f) + row_list = await table.read_rows(query) + assert len(row_list) == bool( + expect_match + ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" From 5cb4e827c280cab8c394cd6b3ef80e75360ec7aa Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 18 Apr 2024 10:57:45 -0700 Subject: [PATCH 050/360] fixed retry type --- .../bigtable_v2/services/bigtable/transports/grpc_asyncio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py b/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py index 7765ecce8..c57a2632c 100644 --- a/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py +++ b/google/cloud/bigtable_v2/services/bigtable/transports/grpc_asyncio.py @@ -529,7 +529,7 @@ def _prep_wrapped_messages(self, client_info): ), self.mutate_row: gapic_v1.method_async.wrap_method( self.mutate_row, - default_retry=retries.Retry( + default_retry=retries.AsyncRetry( initial=0.01, maximum=60.0, multiplier=2, From 46ef676cb8b3b2149d0db83a5f275c35cba59ec6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 18 Apr 2024 10:58:38 -0700 Subject: [PATCH 051/360] renamed test file --- tests/system/data/{test_system.py => test_system_async.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/system/data/{test_system.py => test_system_async.py} (100%) diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system_async.py similarity index 100% rename from tests/system/data/test_system.py rename to tests/system/data/test_system_async.py From 44840f329d100e773b86409b3745de1593c293f7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 18 Apr 2024 11:40:37 -0700 Subject: [PATCH 052/360] got system tests generated --- .../bigtable/data/_sync/system_tests.yaml | 24 + tests/system/data/setup_fixtures.py | 2 - tests/system/data/test_system.py | 783 ++++++++++++++++++ tests/system/data/test_system_async.py | 11 +- 4 files changed, 816 insertions(+), 4 deletions(-) create mode 100644 google/cloud/bigtable/data/_sync/system_tests.yaml create mode 100644 tests/system/data/test_system.py diff --git a/google/cloud/bigtable/data/_sync/system_tests.yaml b/google/cloud/bigtable/data/_sync/system_tests.yaml new file mode 100644 index 000000000..b6f3e64eb --- /dev/null +++ b/google/cloud/bigtable/data/_sync/system_tests.yaml @@ -0,0 +1,24 @@ +asyncio_replacements: # Replace asyncio functionaility + sleep: time.sleep + +added_imports: + - "from .test_system_async import TEST_FAMILY, TEST_FAMILY_2" + - "from google.cloud.bigtable.data import BigtableDataClient" + - "import time" + +text_replacements: + pytest_asyncio: pytest + AsyncRetry: Retry + BigtableDataClientAsync: BigtableDataClient + StopAsyncIteration: StopIteration + __anext__: __next__ + aclose: close + +classes: + - path: tests.system.data.test_system_async.TestSystemAsync + autogen_sync_name: TestSystemSync + drop_methods: ["event_loop"] + - path: tests.system.data.test_system_async.TempRowBuilder + autogen_sync_name: TempRowBuilder + +save_path: "tests/system/data/test_system.py" diff --git a/tests/system/data/setup_fixtures.py 
b/tests/system/data/setup_fixtures.py index 11013938b..3b5a0af06 100644 --- a/tests/system/data/setup_fixtures.py +++ b/tests/system/data/setup_fixtures.py @@ -17,9 +17,7 @@ """ import pytest -import pytest_asyncio import os -import asyncio import uuid diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py new file mode 100644 index 000000000..e8ad50a74 --- /dev/null +++ b/tests/system/data/test_system.py @@ -0,0 +1,783 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by sync_surface_generator.py. Do not edit. + + +from .test_system_async import TEST_FAMILY, TEST_FAMILY_2 +from abc import ABC +from tests.system.data.test_system_async import TempRowBuilder +import os +import pytest +import time +import uuid + +from google.api_core import retry +from google.api_core.exceptions import ClientError +from google.cloud.bigtable.data import BigtableDataClient +from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE +from google.cloud.environment_vars import BIGTABLE_EMULATOR + + +class TestSystemSync(ABC): + @pytest.fixture(scope="session") + def client(self): + project = os.getenv("GOOGLE_CLOUD_PROJECT") or None + with BigtableDataClient(project=project, pool_size=4) as client: + yield client + + @pytest.fixture(scope="session") + def table(self, client, table_id, instance_id): + with client.get_table(instance_id, table_id) as table: + yield table + + @pytest.fixture(scope="session") + def column_family_config(self): + """specify column families to create when creating a new test table""" + from google.cloud.bigtable_admin_v2 import types + + return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} + + @pytest.fixture(scope="session") + def init_table_id(self): + """The table_id to use when creating a new test table""" + return f"test-table-{uuid.uuid4().hex}" + + @pytest.fixture(scope="session") + def cluster_config(self, project_id): + """Configuration for the clusters to use when creating a new instance""" + from google.cloud.bigtable_admin_v2 import types + + cluster = { + "test-cluster": types.Cluster( + location=f"projects/{project_id}/locations/us-central1-b", serve_nodes=1 + ) + } + return cluster + + @pytest.mark.usefixtures("table") + def _retrieve_cell_value(self, table, row_key): + """Helper to read an individual row""" + from google.cloud.bigtable.data import ReadRowsQuery + + row_list = table.read_rows(ReadRowsQuery(row_keys=row_key)) + assert len(row_list) == 1 + row = row_list[0] + cell = row.cells[0] + return cell.value + + def _create_row_and_mutation( + self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" + ): + """Helper to create a new row, and a sample set_cell mutation to change its value""" + from google.cloud.bigtable.data.mutations import SetCell + + row_key = uuid.uuid4().hex.encode() + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row( + row_key, family=family, qualifier=qualifier, 
value=start_value + ) + assert self._retrieve_cell_value(table, row_key) == start_value + mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) + return (row_key, mutation) + + @pytest.fixture(scope="function") + def temp_rows(self, table): + builder = TempRowBuilder(table) + yield builder + builder.delete_rows() + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=10) + def test_ping_and_warm_gapic(self, client, table): + """ + Simple ping rpc test + This test ensures channels are able to authenticate with backend + """ + request = {"name": table.instance_name} + client._gapic_client.ping_and_warm(request) + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) + def test_ping_and_warm(self, client, table): + """Test ping and warm from handwritten client""" + try: + channel = client.transport._grpc_channel.pool[0] + except Exception: + channel = client.transport._grpc_channel + results = client._ping_and_warm_instances(channel) + assert len(results) == 1 + assert results[0] is None + + @pytest.mark.usefixtures("table") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) + def test_mutation_set_cell(self, table, temp_rows): + """Ensure cells can be set properly""" + row_key = b"bulk_mutate" + new_value = uuid.uuid4().hex.encode() + (row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + table.mutate_row(row_key, mutation) + assert self._retrieve_cell_value(table, row_key) == new_value + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" + ) + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) + def test_sample_row_keys(self, client, table, temp_rows, column_split_config): + """Sample keys should return a single sample in small test tables""" + temp_rows.add_row(b"row_key_1") + temp_rows.add_row(b"row_key_2") + results = table.sample_row_keys() + assert len(results) == len(column_split_config) + 1 + for idx in range(len(column_split_config)): + assert results[idx][0] == column_split_config[idx] + assert isinstance(results[idx][1], int) + assert results[-1][0] == b"" + assert isinstance(results[-1][1], int) + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_bulk_mutations_set_cell(self, client, table, temp_rows): + """Ensure cells can be set properly""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + (row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + table.bulk_mutate_rows([bulk_mutation]) + assert self._retrieve_cell_value(table, row_key) == new_value + + def test_bulk_mutations_raise_exception(self, client, table): + """If an invalid mutation is passed, an exception should be raised""" + from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + row_key = uuid.uuid4().hex.encode() + mutation = SetCell( + family="nonexistent", qualifier=b"test-qualifier", 
new_value=b"" + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + with pytest.raises(MutationsExceptionGroup) as exc: + table.bulk_mutate_rows([bulk_mutation]) + assert len(exc.value.exceptions) == 1 + entry_error = exc.value.exceptions[0] + assert isinstance(entry_error, FailedMutationEntryError) + assert entry_error.index == 0 + assert entry_error.entry == bulk_mutation + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) + def test_mutations_batcher_context_manager(self, client, table, temp_rows): + """test batcher with context manager. Should flush on exit""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] + (row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + (row_key2, mutation2) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + with table.mutations_batcher() as batcher: + batcher.append(bulk_mutation) + batcher.append(bulk_mutation2) + assert self._retrieve_cell_value(table, row_key) == new_value + assert len(batcher._staged_entries) == 0 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) + def test_mutations_batcher_timer_flush(self, client, table, temp_rows): + """batch should occur after flush_interval seconds""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + (row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + flush_interval = 0.1 + with table.mutations_batcher(flush_interval=flush_interval) as batcher: + batcher.append(bulk_mutation) + time.sleep(0) + assert len(batcher._staged_entries) == 1 + time.sleep(flush_interval + 0.1) + assert len(batcher._staged_entries) == 0 + assert self._retrieve_cell_value(table, row_key) == new_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) + def test_mutations_batcher_count_flush(self, client, table, temp_rows): + """batch should flush after flush_limit_mutation_count mutations""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] + (row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + (row_key2, mutation2) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: + batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + assert len(batcher._staged_entries) == 1 + batcher.append(bulk_mutation2) + assert len(batcher._flush_jobs) == 1 + for future in list(batcher._flush_jobs): + future + future.result() + assert len(batcher._staged_entries) == 0 + assert len(batcher._flush_jobs) == 0 + assert self._retrieve_cell_value(table, row_key) == new_value + assert self._retrieve_cell_value(table, row_key2) == new_value2 
+ + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) + def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): + """batch should flush after flush_limit_bytes bytes""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] + (row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + (row_key2, mutation2) = self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 + with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: + batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + assert len(batcher._staged_entries) == 1 + batcher.append(bulk_mutation2) + assert len(batcher._flush_jobs) == 1 + assert len(batcher._staged_entries) == 0 + for future in list(batcher._flush_jobs): + future + future.result() + assert self._retrieve_cell_value(table, row_key) == new_value + assert self._retrieve_cell_value(table, row_key2) == new_value2 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_mutations_batcher_no_flush(self, client, table, temp_rows): + """test with no flush requirements met""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + start_value = b"unchanged" + (row_key, mutation) = self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + (row_key2, mutation2) = self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 + with table.mutations_batcher( + flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 + ) as batcher: + batcher.append(bulk_mutation) + assert len(batcher._staged_entries) == 1 + batcher.append(bulk_mutation2) + assert len(batcher._flush_jobs) == 0 + time.sleep(0.01) + assert len(batcher._staged_entries) == 2 + assert len(batcher._flush_jobs) == 0 + assert self._retrieve_cell_value(table, row_key) == start_value + assert self._retrieve_cell_value(table, row_key2) == start_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,increment,expected", + [ + (0, 0, 0), + (0, 1, 1), + (0, -1, -1), + (1, 0, 1), + (0, -100, -100), + (0, 3000, 3000), + (10, 4, 14), + (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), + (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), + (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), + ], + ) + def test_read_modify_write_row_increment( + self, client, table, temp_rows, start, increment, expected + ): + """test read_modify_write_row""" + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + rule = IncrementRule(family, qualifier, increment) + result = table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + 
assert result[0].family == family + assert result[0].qualifier == qualifier + assert int(result[0]) == expected + assert self._retrieve_cell_value(table, row_key) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,append,expected", + [ + (b"", b"", b""), + ("", "", b""), + (b"abc", b"123", b"abc123"), + (b"abc", "123", b"abc123"), + ("", b"1", b"1"), + (b"abc", "", b"abc"), + (b"hello", b"world", b"helloworld"), + ], + ) + def test_read_modify_write_row_append( + self, client, table, temp_rows, start, append, expected + ): + """test read_modify_write_row""" + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + rule = AppendValueRule(family, qualifier, append) + result = table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert result[0].value == expected + assert self._retrieve_cell_value(table, row_key) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_read_modify_write_row_chained(self, client, table, temp_rows): + """test read_modify_write_row with multiple rules""" + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + start_amount = 1 + increment_amount = 10 + temp_rows.add_row( + row_key, value=start_amount, family=family, qualifier=qualifier + ) + rule = [ + IncrementRule(family, qualifier, increment_amount), + AppendValueRule(family, qualifier, "hello"), + AppendValueRule(family, qualifier, "world"), + AppendValueRule(family, qualifier, "!"), + ] + result = table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert result[0].family == family + assert result[0].qualifier == qualifier + assert ( + result[0].value + == (start_amount + increment_amount).to_bytes(8, "big", signed=True) + + b"helloworld!" 
+        )
+        assert self._retrieve_cell_value(table, row_key) == result[0].value
+
+    @pytest.mark.usefixtures("client")
+    @pytest.mark.usefixtures("table")
+    @pytest.mark.parametrize(
+        "start_val,predicate_range,expected_result",
+        [(1, (0, 2), True), (-1, (0, 2), False)],
+    )
+    def test_check_and_mutate(
+        self, client, table, temp_rows, start_val, predicate_range, expected_result
+    ):
+        """test that check_and_mutate_row applies the right mutations, and returns the right result"""
+        from google.cloud.bigtable.data.mutations import SetCell
+        from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+        row_key = b"test-row-key"
+        family = TEST_FAMILY
+        qualifier = b"test-qualifier"
+        temp_rows.add_row(row_key, value=start_val, family=family, qualifier=qualifier)
+        false_mutation_value = b"false-mutation-value"
+        false_mutation = SetCell(
+            family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value
+        )
+        true_mutation_value = b"true-mutation-value"
+        true_mutation = SetCell(
+            family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value
+        )
+        predicate = ValueRangeFilter(predicate_range[0], predicate_range[1])
+        result = table.check_and_mutate_row(
+            row_key,
+            predicate,
+            true_case_mutations=true_mutation,
+            false_case_mutations=false_mutation,
+        )
+        assert result == expected_result
+        expected_value = (
+            true_mutation_value if expected_result else false_mutation_value
+        )
+        assert self._retrieve_cell_value(table, row_key) == expected_value
+
+    @pytest.mark.skipif(
+        bool(os.environ.get(BIGTABLE_EMULATOR)),
+        reason="emulator doesn't raise InvalidArgument",
+    )
+    @pytest.mark.usefixtures("client")
+    @pytest.mark.usefixtures("table")
+    def test_check_and_mutate_empty_request(self, client, table):
+        """check_and_mutate with no true or false mutations should raise an error"""
+        from google.api_core import exceptions
+
+        with pytest.raises(exceptions.InvalidArgument) as e:
+            table.check_and_mutate_row(
+                b"row_key", None, true_case_mutations=None, false_case_mutations=None
+            )
+        assert "No mutations provided" in str(e.value)
+
+    @pytest.mark.usefixtures("table")
+    @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+    def test_read_rows_stream(self, table, temp_rows):
+        """Ensure that the read_rows_stream method works"""
+        temp_rows.add_row(b"row_key_1")
+        temp_rows.add_row(b"row_key_2")
+        generator = table.read_rows_stream({})
+        first_row = generator.__next__()
+        second_row = generator.__next__()
+        assert first_row.row_key == b"row_key_1"
+        assert second_row.row_key == b"row_key_2"
+        with pytest.raises(StopIteration):
+            generator.__next__()
+
+    @pytest.mark.usefixtures("table")
+    @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+    def test_read_rows(self, table, temp_rows):
+        """Ensure that the read_rows method works"""
+        temp_rows.add_row(b"row_key_1")
+        temp_rows.add_row(b"row_key_2")
+        row_list = table.read_rows({})
+        assert len(row_list) == 2
+        assert row_list[0].row_key == b"row_key_1"
+        assert row_list[1].row_key == b"row_key_2"
+
+    @pytest.mark.usefixtures("table")
+    @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
+    def test_read_rows_sharded_simple(self, table, temp_rows):
+        """Test read rows sharded with two queries"""
+        from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+
+        temp_rows.add_row(b"a")
+        temp_rows.add_row(b"b")
+        temp_rows.add_row(b"c")
+        temp_rows.add_row(b"d")
+        query1 = ReadRowsQuery(row_keys=[b"a", b"c"])
+        query2 = 
ReadRowsQuery(row_keys=[b"b", b"d"]) + row_list = table.read_rows_sharded([query1, query2]) + assert len(row_list) == 4 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"b" + assert row_list[3].row_key == b"d" + + @pytest.mark.usefixtures("table") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) + def test_read_rows_sharded_from_sample(self, table, temp_rows): + """Test end-to-end sharding""" + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.read_rows_query import RowRange + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + table_shard_keys = table.sample_row_keys() + query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) + shard_queries = query.shard(table_shard_keys) + row_list = table.read_rows_sharded(shard_queries) + assert len(row_list) == 3 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"d" + + @pytest.mark.usefixtures("table") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) + def test_read_rows_sharded_filters_limits(self, table, temp_rows): + """Test read rows sharded with filters and limits""" + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + label_filter1 = ApplyLabelFilter("first") + label_filter2 = ApplyLabelFilter("second") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) + row_list = table.read_rows_sharded([query1, query2]) + assert len(row_list) == 3 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"b" + assert row_list[2].row_key == b"d" + assert row_list[0][0].labels == ["first"] + assert row_list[1][0].labels == ["second"] + assert row_list[2][0].labels == ["second"] + + @pytest.mark.usefixtures("table") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) + def test_read_rows_range_query(self, table, temp_rows): + """Ensure that the read_rows method works""" + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data import RowRange + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) + row_list = table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) + def test_read_rows_single_key_query(self, table, temp_rows): + """Ensure that the read_rows method works with specified query""" + from google.cloud.bigtable.data import ReadRowsQuery + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + query = ReadRowsQuery(row_keys=[b"a", b"c"]) + row_list = table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) + def 
test_read_rows_with_filter(self, table, temp_rows): + """ensure filters are applied""" + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + expected_label = "test-label" + row_filter = ApplyLabelFilter(expected_label) + query = ReadRowsQuery(row_filter=row_filter) + row_list = table.read_rows(query) + assert len(row_list) == 4 + for row in row_list: + assert row[0].labels == [expected_label] + + @pytest.mark.usefixtures("table") + def test_read_rows_stream_close(self, table, temp_rows): + """Ensure that the read_rows_stream can be closed""" + from google.cloud.bigtable.data import ReadRowsQuery + + temp_rows.add_row(b"row_key_1") + temp_rows.add_row(b"row_key_2") + query = ReadRowsQuery() + generator = table.read_rows_stream(query) + first_row = generator.__next__() + assert first_row.row_key == b"row_key_1" + generator.close() + with pytest.raises(StopIteration): + generator.__next__() + + @pytest.mark.usefixtures("table") + def test_read_row(self, table, temp_rows): + """Test read_row (single row helper)""" + from google.cloud.bigtable.data import Row + + temp_rows.add_row(b"row_key_1", value=b"value") + row = table.read_row(b"row_key_1") + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + def test_read_row_missing(self, table): + """Test read_row when row does not exist""" + from google.api_core import exceptions + + row_key = "row_key_not_exist" + result = table.read_row(row_key) + assert result is None + with pytest.raises(exceptions.InvalidArgument) as e: + table.read_row("") + assert "Row keys must be non-empty" in str(e) + + @pytest.mark.usefixtures("table") + def test_read_row_w_filter(self, table, temp_rows): + """Test read_row (single row helper)""" + from google.cloud.bigtable.data import Row + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + temp_rows.add_row(b"row_key_1", value=b"value") + expected_label = "test-label" + label_filter = ApplyLabelFilter(expected_label) + row = table.read_row(b"row_key_1", row_filter=label_filter) + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + assert row.cells[0].labels == [expected_label] + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + def test_row_exists(self, table, temp_rows): + from google.api_core import exceptions + + "Test row_exists with rows that exist and don't exist" + assert table.row_exists(b"row_key_1") is False + temp_rows.add_row(b"row_key_1") + assert table.row_exists(b"row_key_1") is True + assert table.row_exists("row_key_1") is True + assert table.row_exists(b"row_key_2") is False + assert table.row_exists("row_key_2") is False + assert table.row_exists("3") is False + temp_rows.add_row(b"3") + assert table.row_exists(b"3") is True + with pytest.raises(exceptions.InvalidArgument) as e: + table.row_exists("") + assert "Row keys must be non-empty" in str(e) + + @pytest.mark.usefixtures("table") + @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) + @pytest.mark.parametrize( + 
"cell_value,filter_input,expect_match", + [ + (b"abc", b"abc", True), + (b"abc", "abc", True), + (b".", ".", True), + (".*", ".*", True), + (".*", b".*", True), + ("a", ".*", False), + (b".*", b".*", True), + ("\\a", "\\a", True), + (b"\xe2\x98\x83", "☃", True), + ("☃", "☃", True), + ("\\C☃", "\\C☃", True), + (1, 1, True), + (2, 1, False), + (68, 68, True), + ("D", 68, False), + (68, "D", False), + (-1, -1, True), + (2852126720, 2852126720, True), + (-1431655766, -1431655766, True), + (-1431655766, -1, False), + ], + ) + def test_literal_value_filter( + self, table, temp_rows, cell_value, filter_input, expect_match + ): + """ + Literal value filter does complex escaping on re2 strings. + Make sure inputs are properly interpreted by the server + """ + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + from google.cloud.bigtable.data import ReadRowsQuery + + f = LiteralValueFilter(filter_input) + temp_rows.add_row(b"row_key_1", value=cell_value) + query = ReadRowsQuery(row_filter=f) + row_list = table.read_rows(query) + assert len(row_list) == bool( + expect_match + ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" + + +class TempRowBuilder(ABC): + """ + Used to add rows to a table for testing purposes. + """ + + def __init__(self, table): + self.rows = [] + self.table = table + + def add_row( + self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" + ): + if isinstance(value, str): + value = value.encode("utf-8") + elif isinstance(value, int): + value = value.to_bytes(8, byteorder="big", signed=True) + request = { + "table_name": self.table.table_name, + "row_key": row_key, + "mutations": [ + { + "set_cell": { + "family_name": family, + "column_qualifier": qualifier, + "value": value, + } + } + ], + } + self.table.client._gapic_client.mutate_row(request) + self.rows.append(row_key) + + def delete_rows(self): + if self.rows: + request = { + "table_name": self.table.table_name, + "entries": [ + {"row_key": row, "mutations": [{"delete_from_row": {}}]} + for row in self.rows + ], + } + self.table.client._gapic_client.mutate_rows(request) diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 3a229fc28..8d7c86ea4 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -372,7 +372,11 @@ async def test_mutations_batcher_count_flush(self, client, table, temp_rows): await batcher.append(bulk_mutation2) # task should now be scheduled assert len(batcher._flush_jobs) == 1 - await asyncio.gather(*batcher._flush_jobs) + # let flush complete + for future in list(batcher._flush_jobs): + await future + # for sync version: grab result + future.result() assert len(batcher._staged_entries) == 0 assert len(batcher._flush_jobs) == 0 # ensure cells were updated @@ -412,7 +416,10 @@ async def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): assert len(batcher._flush_jobs) == 1 assert len(batcher._staged_entries) == 0 # let flush complete - await asyncio.gather(*batcher._flush_jobs) + for future in list(batcher._flush_jobs): + await future + # for sync version: grab result + future.result() # ensure cells were updated assert (await self._retrieve_cell_value(table, row_key)) == new_value assert (await self._retrieve_cell_value(table, row_key2)) == new_value2 From 3c339543ace6d9b051baf624306ffffa427dae3c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 18 Apr 2024 11:40:52 -0700 Subject: [PATCH 053/360] fixed 
possible flake --- google/cloud/bigtable/data/_sync/mutations_batcher.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 5a60e0831..ee5a3aac7 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -51,7 +51,8 @@ def close(self): # attempt cancel timer if not started self._flush_timer.cancel() self._schedule_flush() - self._executor.shutdown(wait=True) + with self._executor: + self._executor.shutdown(wait=True) atexit.unregister(self._on_exit) # raise unreported exceptions self._raise_exceptions() From 7c9c3bdf40277b2af9805f59fa449cedfa839223 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 18 Apr 2024 11:41:52 -0700 Subject: [PATCH 054/360] changed class order to match async --- .../bigtable/data/_sync/system_tests.yaml | 4 +- tests/system/data/test_system.py | 88 +++++++++---------- 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/system_tests.yaml b/google/cloud/bigtable/data/_sync/system_tests.yaml index b6f3e64eb..d8a3c7c53 100644 --- a/google/cloud/bigtable/data/_sync/system_tests.yaml +++ b/google/cloud/bigtable/data/_sync/system_tests.yaml @@ -15,10 +15,10 @@ text_replacements: aclose: close classes: + - path: tests.system.data.test_system_async.TempRowBuilder + autogen_sync_name: TempRowBuilder - path: tests.system.data.test_system_async.TestSystemAsync autogen_sync_name: TestSystemSync drop_methods: ["event_loop"] - - path: tests.system.data.test_system_async.TempRowBuilder - autogen_sync_name: TempRowBuilder save_path: "tests/system/data/test_system.py" diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py index e8ad50a74..3ed20f1e1 100644 --- a/tests/system/data/test_system.py +++ b/tests/system/data/test_system.py @@ -30,6 +30,50 @@ from google.cloud.environment_vars import BIGTABLE_EMULATOR +class TempRowBuilder(ABC): + """ + Used to add rows to a table for testing purposes. + """ + + def __init__(self, table): + self.rows = [] + self.table = table + + def add_row( + self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" + ): + if isinstance(value, str): + value = value.encode("utf-8") + elif isinstance(value, int): + value = value.to_bytes(8, byteorder="big", signed=True) + request = { + "table_name": self.table.table_name, + "row_key": row_key, + "mutations": [ + { + "set_cell": { + "family_name": family, + "column_qualifier": qualifier, + "value": value, + } + } + ], + } + self.table.client._gapic_client.mutate_row(request) + self.rows.append(row_key) + + def delete_rows(self): + if self.rows: + request = { + "table_name": self.table.table_name, + "entries": [ + {"row_key": row, "mutations": [{"delete_from_row": {}}]} + for row in self.rows + ], + } + self.table.client._gapic_client.mutate_rows(request) + + class TestSystemSync(ABC): @pytest.fixture(scope="session") def client(self): @@ -737,47 +781,3 @@ def test_literal_value_filter( assert len(row_list) == bool( expect_match ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" - - -class TempRowBuilder(ABC): - """ - Used to add rows to a table for testing purposes. 
- """ - - def __init__(self, table): - self.rows = [] - self.table = table - - def add_row( - self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" - ): - if isinstance(value, str): - value = value.encode("utf-8") - elif isinstance(value, int): - value = value.to_bytes(8, byteorder="big", signed=True) - request = { - "table_name": self.table.table_name, - "row_key": row_key, - "mutations": [ - { - "set_cell": { - "family_name": family, - "column_qualifier": qualifier, - "value": value, - } - } - ], - } - self.table.client._gapic_client.mutate_row(request) - self.rows.append(row_key) - - def delete_rows(self): - if self.rows: - request = { - "table_name": self.table.table_name, - "entries": [ - {"row_key": row, "mutations": [{"delete_from_row": {}}]} - for row in self.rows - ], - } - self.table.client._gapic_client.mutate_rows(request) From 7661fec29f2d391bc9e9c51275eaca87adbb44e3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 18 Apr 2024 13:00:08 -0700 Subject: [PATCH 055/360] renamed class --- google/cloud/bigtable/data/_sync/system_tests.yaml | 3 ++- tests/system/data/test_system.py | 1 - tests/system/data/test_system_async.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/system_tests.yaml b/google/cloud/bigtable/data/_sync/system_tests.yaml index d8a3c7c53..43c78aa9e 100644 --- a/google/cloud/bigtable/data/_sync/system_tests.yaml +++ b/google/cloud/bigtable/data/_sync/system_tests.yaml @@ -10,12 +10,13 @@ text_replacements: pytest_asyncio: pytest AsyncRetry: Retry BigtableDataClientAsync: BigtableDataClient + TempRowBuilderAsync: TempRowBuilder StopAsyncIteration: StopIteration __anext__: __next__ aclose: close classes: - - path: tests.system.data.test_system_async.TempRowBuilder + - path: tests.system.data.test_system_async.TempRowBuilderAsync autogen_sync_name: TempRowBuilder - path: tests.system.data.test_system_async.TestSystemAsync autogen_sync_name: TestSystemSync diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py index 3ed20f1e1..33710d808 100644 --- a/tests/system/data/test_system.py +++ b/tests/system/data/test_system.py @@ -17,7 +17,6 @@ from .test_system_async import TEST_FAMILY, TEST_FAMILY_2 from abc import ABC -from tests.system.data.test_system_async import TempRowBuilder import os import pytest import time diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 8d7c86ea4..69a23412e 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -27,7 +27,7 @@ TEST_FAMILY_2 = "test-family-2" -class TempRowBuilder: +class TempRowBuilderAsync: """ Used to add rows to a table for testing purposes. 
""" @@ -161,7 +161,7 @@ async def _create_row_and_mutation( @pytest_asyncio.fixture(scope="function") async def temp_rows(self, table): - builder = TempRowBuilder(table) + builder = TempRowBuilderAsync(table) yield builder await builder.delete_rows() From 8e2db64ec3b87f282aa25968021957f4c10125ae Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 19 Apr 2024 11:38:08 -0700 Subject: [PATCH 056/360] experiment with CrossSync annotations --- google/cloud/bigtable/data/_async/client.py | 104 +++++++----------- google/cloud/bigtable/data/_sync/client.py | 3 +- .../cloud/bigtable/data/_sync/cross_sync.py | 59 ++++++++++ 3 files changed, 101 insertions(+), 65 deletions(-) create mode 100644 google/cloud/bigtable/data/_sync/cross_sync.py diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index aaed11cca..3c0716ca7 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -31,6 +31,7 @@ import sys import random import os +import concurrent.futures from functools import partial from grpc import Channel @@ -74,13 +75,21 @@ from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter from google.cloud.bigtable.data.row_filters import RowFilterChain +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery +if CrossSync.SyncImports: + from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport, PooledChannel + +@CrossSync.sync_output("google.cloud.bigtable._sync._autogen.BigtableDataClient") class BigtableDataClientAsync(ClientWithProject): + def __init__( self, *, @@ -162,6 +171,7 @@ def __init__( self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() self._channel_refresh_tasks: list[asyncio.Task[None]] = [] + self._executor = concurrent.futures.ThreadPoolExecutor() if not CrossSync.is_async else None if self._emulator_host is not None: # connect to an emulator host warnings.warn( @@ -194,7 +204,10 @@ def _client_version() -> str: """ Helper function to return the client version string for this client """ - return f"{google.cloud.bigtable.__version__}-data-async" + if CrossSync.is_async: + return f"{google.cloud.bigtable.__version__}-data-async" + else: + return f"{google.cloud.bigtable.__version__}-data" def _start_background_channel_refresh(self) -> None: """ @@ -207,28 +220,26 @@ def _start_background_channel_refresh(self) -> None: and not self._emulator_host and not self._is_closed.is_set() ): - # raise RuntimeError if there is no event loop - asyncio.get_running_loop() for channel_idx in range(self.transport.pool_size): - refresh_task = asyncio.create_task(self._manage_channel(channel_idx)) - if sys.version_info >= (3, 8): - # task names supported in Python 3.8+ - refresh_task.set_name( - f"{self.__class__.__name__} channel refresh {channel_idx}" - ) + refresh_task = CrossSync.create_task( + self._manage_channel, channel_idx, + sync_executor=self._executor, + task_name=f"{self.__class__.__name__} channel refresh {channel_idx}" + ) self._channel_refresh_tasks.append(refresh_task) + refresh_task.add_done_callback(lambda _: self._channel_refresh_tasks.remove(refresh_task)) - async def close(self, timeout: float = 2.0): + async def close(self, timeout: 
float | None = None): """ Cancel all background tasks """ self._is_closed.set() for task in self._channel_refresh_tasks: task.cancel() - group = asyncio.gather(*self._channel_refresh_tasks, return_exceptions=True) - await asyncio.wait_for(group, timeout=timeout) await self.transport.close() - self._channel_refresh_tasks = [] + if self._executor: + self._executor.shutdown(wait=False) + await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) async def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None @@ -266,23 +277,7 @@ async def _ping_and_warm_instances( ) for (instance_name, table_name, app_profile_id) in instance_list ] - return await self._execute_ping_and_warms(*partial_list) - - async def _execute_ping_and_warms(self, *fns) -> list[BaseException | None]: - """ - Execute batch of ping and warm requests in parallel - - Will have separate implementation for sync and async clients - - Args: - - fns: list of partial functions to execute ping and warm requests - Returns: - - list of results or exceptions from the ping requests - """ - # extract coroutine out of partials - coro_list = [fn() for fn in fns] - result_list = await asyncio.gather(*coro_list, return_exceptions=True) - # return None in place of empty successful responses + result_list = await CrossSync.gather_partials(partial_list, return_exceptions=True, sync_executor=self._executor) return [r or None for r in result_list] async def _manage_channel( @@ -311,6 +306,7 @@ async def _manage_channel( grace_period: time to allow previous channel to serve existing requests before closing, in seconds """ + sleep_fn = asyncio.sleep if CrossSync.is_async else self._is_closed.wait first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) @@ -321,7 +317,7 @@ async def _manage_channel( await self._ping_and_warm_instances(channel) # continuously refresh the channel every `refresh_interval` seconds while not self._is_closed.is_set(): - await asyncio.sleep(next_sleep) + await sleep_fn(next_sleep) if self._is_closed.is_set(): break # prepare new channel for use @@ -561,20 +557,10 @@ def __init__( default_mutate_rows_retryable_errors or () ) self.default_retryable_errors = default_retryable_errors or () - self._register_instance_future: asyncio.Future[ - None - ] = self._register_with_client() - - def _register_with_client(self) -> asyncio.Future[None]: - """ - Calls the client's _register_instance method to warm the grpc channels for this instance - - Different implementations for sync vs async client - """ - # raises RuntimeError if called outside of an async context (no running event loop) try: - return asyncio.create_task( - self.client._register_instance(self.instance_id, self) + self._register_instance_future = CrossSync.create_task( + self.client._register_instance, self.instance_id, self, + sync_executor=self.client._executor, ) except RuntimeError as e: raise RuntimeError( @@ -797,16 +783,19 @@ async def read_rows_sharded( shard_idx = 0 for batch in batched_queries: batch_operation_timeout = next(timeout_generator) - batch_kwargs_list = [ - { - "query": query, - "operation_timeout": batch_operation_timeout, - "attempt_timeout": min(attempt_timeout, batch_operation_timeout), - "retryable_errors": retryable_errors, - } + batch_partial_list = [ + partial( + self.read_rows, + query=query, + operation_timeout=batch_operation_timeout, + attempt_timeout=min(attempt_timeout, batch_operation_timeout), + retryable_errors=retryable_errors, 
+ ) for query in batch ] - batch_result = await self._shard_batch_helper(batch_kwargs_list) + batch_result = await CrossSync.gather_partials( + batch_partial_list, return_exceptions=True, sync_executor=self.client._executor + ) for result in batch_result: if isinstance(result, Exception): error_dict[shard_idx] = result @@ -828,17 +817,6 @@ async def read_rows_sharded( ) return results_list - async def _shard_batch_helper( - self, kwargs_list: list[dict] - ) -> list[list[Row] | BaseException]: - """ - Helper function for executing a batch of read_rows queries in parallel - - Sync client implementation will override this method - """ - routine_list = [self.read_rows(**kwargs) for kwargs in kwargs_list] - return await asyncio.gather(*routine_list, return_exceptions=True) - async def row_exists( self, row_key: str | bytes, diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index c591b91c0..b577f37e1 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -70,8 +70,7 @@ def close(self) -> None: This method should be called when the client is no longer needed. """ self._is_closed.set() - with self._executor: - self._executor.shutdown(wait=False) + self._executor.shutdown(wait=True) self._channel_refresh_tasks = [] self.transport.close() diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py new file mode 100644 index 000000000..dbe8f1c1c --- /dev/null +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -0,0 +1,59 @@ +import asyncio +import sys + + +class CrossSync: + + SyncImports = False + is_async = True + + sleep = asyncio.sleep + Queue = asyncio.Queue + Condition = asyncio.Condition + Future = asyncio.Future + Task = asyncio.Task + Event = asyncio.Event + + @classmethod + def sync_output(cls, sync_path): + # return the async class unchanged + return lambda async_cls: async_cls + + @staticmethod + async def gather_partials(partial_list, return_exceptions=False, sync_executor=None): + """ + abstraction over asyncio.gather + + In the async version, the partials are expected to return an awaitable object. Patials + are unpacked and awaited in the gather call. + + Sync version implemented with threadpool executor + + Returns: + - a list of results (or exceptions, if return_exceptions=True) in the same order as partial_list + """ + if not partial_list: + return [] + awaitable_list = [partial() for partial in partial_list] + return await asyncio.gather(*awaitable_list, return_exceptions=return_exceptions) + + @staticmethod + async def wait(futures, timeout=None): + """ + abstraction over asyncio.wait + """ + if not futures: + return set(), set() + return await asyncio.wait(futures, timeout=timeout) + + @staticmethod + def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): + """ + abstraction over asyncio.create_task. Sync version implemented with threadpool executor + + sync_executor: ThreadPoolExecutor to use for sync operations. 
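# Illustrative sketch (an assumption, not code from this patch): the async
# gather_partials above awaits the partials with asyncio.gather, and its docstring
# says the sync version is implemented with a thread pool executor. One way that
# sync counterpart could look, preserving input ordering and the return_exceptions
# behaviour; the name gather_partials_sync is hypothetical.
import concurrent.futures


def gather_partials_sync(partial_list, return_exceptions=False, sync_executor=None):
    """Run zero-argument callables in parallel, returning results in order."""
    if not partial_list:
        return []
    if sync_executor is None:
        raise ValueError("sync_executor is required for the sync implementation")
    futures = [sync_executor.submit(p) for p in partial_list]
    results = []
    for future in futures:
        try:
            results.append(future.result())
        except Exception as exc:  # mirrors asyncio.gather(return_exceptions=True)
            if return_exceptions:
                results.append(exc)
            else:
                raise
    return results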
Ignored in async version + """ + task = asyncio.create_task(fn(*fn_args, **fn_kwargs)) + if task_name and sys.version_info >= (3, 8): + task.set_name(task_name) + return task From bbe50056abea54ad5b35ab66c23fa41d76f45111 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 19 Apr 2024 12:20:03 -0700 Subject: [PATCH 057/360] hacked together sync generator to work with annoptations --- google/cloud/bigtable/data/_async/client.py | 2 +- google/cloud/bigtable/data/_sync/_autogen.py | 1721 +---------------- .../cloud/bigtable/data/_sync/cross_sync.py | 8 +- sync_surface_generator.py | 135 +- 4 files changed, 183 insertions(+), 1683 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 3c0716ca7..02e0ce05e 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -87,7 +87,7 @@ from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport, PooledChannel -@CrossSync.sync_output("google.cloud.bigtable._sync._autogen.BigtableDataClient") +@CrossSync.sync_output("google.cloud.bigtable.data._sync._autogen.BigtableDataClient_SyncGen") class BigtableDataClientAsync(ClientWithProject): def __init__( diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py index a0bb005b7..e537b3ac6 100644 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ b/google/cloud/bigtable/data/_sync/_autogen.py @@ -17,850 +17,43 @@ from __future__ import annotations from abc import ABC -from collections import deque from functools import partial from grpc import Channel from typing import Any -from typing import Iterable from typing import Optional -from typing import Sequence from typing import Set from typing import cast -import atexit +import asyncio import concurrent.futures -import functools import os import random -import threading import time import warnings from google.api_core import client_options as client_options_lib -from google.api_core import exceptions as core_exceptions -from google.api_core import retry as retries -from google.api_core.exceptions import Aborted -from google.api_core.exceptions import DeadlineExceeded -from google.api_core.exceptions import ServiceUnavailable -from google.api_core.retry import exponential_sleep_generator from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT from google.cloud.bigtable.data import _helpers -from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto -from google.cloud.bigtable.data._async._read_rows import _ResetRow -from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE -from google.cloud.bigtable.data._helpers import RowKeySamples -from google.cloud.bigtable.data._helpers import ShardedQuery -from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data._helpers import _attempt_timeout_generator -from google.cloud.bigtable.data._helpers import _get_retryable_errors -from google.cloud.bigtable.data._helpers import _get_timeouts -from google.cloud.bigtable.data._helpers import _make_metadata -from google.cloud.bigtable.data._helpers import _retry_exception_factory -from google.cloud.bigtable.data.exceptions import FailedMutationEntryError -from google.cloud.bigtable.data.exceptions import FailedQueryShardError -from google.cloud.bigtable.data.exceptions import InvalidChunk -from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup -from 
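# Illustrative sketch (an assumption, not code from this patch): a sync counterpart
# of CrossSync.create_task could submit the function to the supplied
# ThreadPoolExecutor and return the resulting concurrent.futures.Future. task_name
# has no direct equivalent on a Future, so it is ignored here; the name
# create_task_sync is hypothetical.
import concurrent.futures


def create_task_sync(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs):
    """Start fn(*fn_args, **fn_kwargs) in the background and return a Future."""
    if sync_executor is None:
        raise ValueError("sync_executor is required for the sync implementation")
    return sync_executor.submit(fn, *fn_args, **fn_kwargs)


if __name__ == "__main__":
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = create_task_sync(pow, 2, 10, sync_executor=executor)
        print(future.result())  # 1024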
google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup -from google.cloud.bigtable.data.exceptions import _RowSetComplete -from google.cloud.bigtable.data.mutations import Mutation -from google.cloud.bigtable.data.mutations import RowMutationEntry -from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT -from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule -from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery -from google.cloud.bigtable.data.row import Cell -from google.cloud.bigtable.data.row import Row -from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter -from google.cloud.bigtable.data.row_filters import RowFilter -from google.cloud.bigtable.data.row_filters import RowFilterChain -from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter +from google.cloud.bigtable.data._async.client import TableAsync +from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO -from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledBigtableGrpcTransport, - PooledChannel, +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, +) +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel as AsyncPooledChannel, ) -from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB -from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB -from google.cloud.bigtable_v2.types import RowRange as RowRangePB -from google.cloud.bigtable_v2.types import RowSet as RowSetPB from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest from google.cloud.client import ClientWithProject from google.cloud.environment_vars import BIGTABLE_EMULATOR import google.auth._default import google.auth.credentials -import google.cloud.bigtable.data.exceptions -import google.cloud.bigtable.data.exceptions as bt_exceptions -import google.cloud.bigtable_v2.types.bigtable - - -class _ReadRowsOperation_SyncGen(ABC): - """ - ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream - into a stream of Row objects. - - ReadRowsOperation.merge_row_response_stream takes in a stream of ReadRowsResponse - and turns them into a stream of Row objects using an internal - StateMachine. - - ReadRowsOperation(request, client) handles row merging logic end-to-end, including - performing retries on stream errors. 
- """ - - __slots__ = ( - "attempt_timeout_gen", - "operation_timeout", - "request", - "table", - "_predicate", - "_metadata", - "_last_yielded_row_key", - "_remaining_count", - ) - - def __init__( - self, - query: ReadRowsQuery, - table: "google.cloud.bigtable.data._sync.client.Table", - operation_timeout: float, - attempt_timeout: float, - retryable_exceptions: Sequence[type[Exception]] = (), - ): - self.attempt_timeout_gen = _helpers._attempt_timeout_generator( - attempt_timeout, operation_timeout - ) - self.operation_timeout = operation_timeout - if isinstance(query, dict): - self.request = ReadRowsRequestPB( - **query, - table_name=table.table_name, - app_profile_id=table.app_profile_id, - ) - else: - self.request = query._to_pb(table) - self.table = table - self._predicate = retries.if_exception_type(*retryable_exceptions) - self._metadata = _helpers._make_metadata(table.table_name, table.app_profile_id) - self._last_yielded_row_key: bytes | None = None - self._remaining_count: int | None = self.request.rows_limit or None - - def start_operation(self) -> Iterable[Row]: - """Start the read_rows operation, retrying on retryable errors.""" - return retries.retry_target_stream( - self._read_rows_attempt, - self._predicate, - exponential_sleep_generator(0.01, 60, multiplier=2), - self.operation_timeout, - exception_factory=_helpers._retry_exception_factory, - ) - - def _read_rows_attempt(self) -> Iterable[Row]: - """ - Attempt a single read_rows rpc call. - This function is intended to be wrapped by retry logic, - which will call this function until it succeeds or - a non-retryable error is raised. - """ - if self._last_yielded_row_key is not None: - try: - self.request.rows = self._revise_request_rowset( - row_set=self.request.rows, - last_seen_row_key=self._last_yielded_row_key, - ) - except _RowSetComplete: - return self.merge_rows(None) - if self._remaining_count is not None: - self.request.rows_limit = self._remaining_count - if self._remaining_count == 0: - return self.merge_rows(None) - gapic_stream = self.table.client._gapic_client.read_rows( - self.request, - timeout=next(self.attempt_timeout_gen), - metadata=self._metadata, - retry=None, - ) - chunked_stream = self.chunk_stream(gapic_stream) - return self.merge_rows(chunked_stream) - - def chunk_stream( - self, stream: None[Iterable[ReadRowsResponsePB]] - ) -> Iterable[ReadRowsResponsePB.CellChunk]: - """process chunks out of raw read_rows stream""" - for resp in stream: - resp = resp._pb - if resp.last_scanned_row_key: - if ( - self._last_yielded_row_key is not None - and resp.last_scanned_row_key <= self._last_yielded_row_key - ): - raise InvalidChunk("last scanned out of order") - self._last_yielded_row_key = resp.last_scanned_row_key - current_key = None - for c in resp.chunks: - if current_key is None: - current_key = c.row_key - if current_key is None: - raise InvalidChunk("first chunk is missing a row key") - elif ( - self._last_yielded_row_key - and current_key <= self._last_yielded_row_key - ): - raise InvalidChunk("row keys should be strictly increasing") - yield c - if c.reset_row: - current_key = None - elif c.commit_row: - self._last_yielded_row_key = current_key - if self._remaining_count is not None: - self._remaining_count -= 1 - if self._remaining_count < 0: - raise InvalidChunk("emit count exceeds row limit") - current_key = None - - @staticmethod - def merge_rows(chunks: Iterable[ReadRowsResponsePB.CellChunk] | None): - """Merge chunks into rows""" - if chunks is None: - return - it = chunks.__iter__() - while 
True: - try: - c = it.__next__() - except StopIteration: - return - row_key = c.row_key - if not row_key: - raise InvalidChunk("first row chunk is missing key") - cells = [] - family: str | None = None - qualifier: bytes | None = None - try: - while True: - if c.reset_row: - raise _ResetRow(c) - k = c.row_key - f = c.family_name.value - q = c.qualifier.value if c.HasField("qualifier") else None - if k and k != row_key: - raise InvalidChunk("unexpected new row key") - if f: - family = f - if q is not None: - qualifier = q - else: - raise InvalidChunk("new family without qualifier") - elif family is None: - raise InvalidChunk("missing family") - elif q is not None: - if family is None: - raise InvalidChunk("new qualifier without family") - qualifier = q - elif qualifier is None: - raise InvalidChunk("missing qualifier") - ts = c.timestamp_micros - labels = c.labels if c.labels else [] - value = c.value - if c.value_size > 0: - buffer = [value] - while c.value_size > 0: - c = it.__next__() - t = c.timestamp_micros - cl = c.labels - k = c.row_key - if ( - c.HasField("family_name") - and c.family_name.value != family - ): - raise InvalidChunk("family changed mid cell") - if ( - c.HasField("qualifier") - and c.qualifier.value != qualifier - ): - raise InvalidChunk("qualifier changed mid cell") - if t and t != ts: - raise InvalidChunk("timestamp changed mid cell") - if cl and cl != labels: - raise InvalidChunk("labels changed mid cell") - if k and k != row_key: - raise InvalidChunk("row key changed mid cell") - if c.reset_row: - raise _ResetRow(c) - buffer.append(c.value) - value = b"".join(buffer) - cells.append( - Cell(value, row_key, family, qualifier, ts, list(labels)) - ) - if c.commit_row: - yield Row(row_key, cells) - break - c = it.__next__() - except _ResetRow as e: - c = e.chunk - if ( - c.row_key - or c.HasField("family_name") - or c.HasField("qualifier") - or c.timestamp_micros - or c.labels - or c.value - ): - raise InvalidChunk("reset row with data") - continue - except StopIteration: - raise InvalidChunk("premature end of stream") - - @staticmethod - def _revise_request_rowset(row_set: RowSetPB, last_seen_row_key: bytes) -> RowSetPB: - """ - Revise the rows in the request to avoid ones we've already processed. - - Args: - - row_set: the row set from the request - - last_seen_row_key: the last row key encountered - Raises: - - _RowSetComplete: if there are no rows left to process after the revision - """ - if row_set is None or (not row_set.row_ranges and row_set.row_keys is not None): - last_seen = last_seen_row_key - return RowSetPB(row_ranges=[RowRangePB(start_key_open=last_seen)]) - adjusted_keys: list[bytes] = [ - k for k in row_set.row_keys if k > last_seen_row_key - ] - adjusted_ranges: list[RowRangePB] = [] - for row_range in row_set.row_ranges: - end_key = row_range.end_key_closed or row_range.end_key_open or None - if end_key is None or end_key > last_seen_row_key: - new_range = RowRangePB(row_range) - start_key = row_range.start_key_closed or row_range.start_key_open - if start_key is None or start_key <= last_seen_row_key: - new_range.start_key_open = last_seen_row_key - adjusted_ranges.append(new_range) - if len(adjusted_keys) == 0 and len(adjusted_ranges) == 0: - raise _RowSetComplete() - return RowSetPB(row_keys=adjusted_keys, row_ranges=adjusted_ranges) - - -class _MutateRowsOperation_SyncGen(ABC): - """ - MutateRowsOperation manages the logic of sending a set of row mutations, - and retrying on failed entries. 
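# Worked example of the row-set revision logic in _revise_request_rowset above,
# using plain Python values instead of the RowSetPB / RowRangePB protos (a
# simplification for illustration only, and it ignores open vs. closed bounds):
# keys at or before the last yielded key are dropped, and any range that still
# extends past that key is re-opened just after it.
def revise_rowset(row_keys, row_ranges, last_seen):
    """row_ranges are (start_key, end_key) tuples; None means unbounded."""
    kept_keys = [k for k in row_keys if k > last_seen]
    kept_ranges = []
    for start, end in row_ranges:
        if end is None or end > last_seen:
            new_start = start if (start is not None and start > last_seen) else last_seen
            kept_ranges.append((new_start, end))
    if not kept_keys and not kept_ranges:
        raise ValueError("row set complete")  # stands in for _RowSetComplete
    return kept_keys, kept_ranges


print(revise_rowset([b"a", b"m", b"z"], [(b"a", b"q")], last_seen=b"m"))
# -> ([b'z'], [(b'm', b'q')])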
It manages this using the _run_attempt - function, which attempts to mutate all outstanding entries, and raises - _MutateRowsIncomplete if any retryable errors are encountered. - - Errors are exposed as a MutationsExceptionGroup, which contains a list of - exceptions organized by the related failed mutation entries. - """ - - def __init__( - self, - gapic_client: "BigtableClient", - table: "google.cloud.bigtable.data._sync.client.Table", - mutation_entries: list["RowMutationEntry"], - operation_timeout: float, - attempt_timeout: float | None, - retryable_exceptions: Sequence[type[Exception]] = (), - ): - """ - Args: - - gapic_client: the client to use for the mutate_rows call - - table: the table associated with the request - - mutation_entries: a list of RowMutationEntry objects to send to the server - - operation_timeout: the timeout to use for the entire operation, in seconds. - - attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds. - If not specified, the request will run until operation_timeout is reached. - """ - total_mutations = sum((len(entry.mutations) for entry in mutation_entries)) - if total_mutations > _MUTATE_ROWS_REQUEST_MUTATION_LIMIT: - raise ValueError( - f"mutate_rows requests can contain at most {_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across all entries. Found {total_mutations}." - ) - metadata = _make_metadata(table.table_name, table.app_profile_id) - self._gapic_fn = functools.partial( - gapic_client.mutate_rows, - table_name=table.table_name, - app_profile_id=table.app_profile_id, - metadata=metadata, - retry=None, - ) - self.is_retryable = retries.if_exception_type( - *retryable_exceptions, bt_exceptions._MutateRowsIncomplete - ) - sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - self._operation = lambda: retries.retry_target( - self._run_attempt, - self.is_retryable, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, - ) - self.timeout_generator = _attempt_timeout_generator( - attempt_timeout, operation_timeout - ) - self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries] - self.remaining_indices = list(range(len(self.mutations))) - self.errors: dict[int, list[Exception]] = {} - - def start(self): - """ - Start the operation, and run until completion - - Raises: - - MutationsExceptionGroup: if any mutations failed - """ - try: - self._operation() - except Exception as exc: - incomplete_indices = self.remaining_indices.copy() - for idx in incomplete_indices: - self._handle_entry_error(idx, exc) - finally: - all_errors: list[Exception] = [] - for idx, exc_list in self.errors.items(): - if len(exc_list) == 0: - raise core_exceptions.ClientError( - f"Mutation {idx} failed with no associated errors" - ) - elif len(exc_list) == 1: - cause_exc = exc_list[0] - else: - cause_exc = bt_exceptions.RetryExceptionGroup(exc_list) - entry = self.mutations[idx].entry - all_errors.append( - bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc) - ) - if all_errors: - raise bt_exceptions.MutationsExceptionGroup( - all_errors, len(self.mutations) - ) - - def _run_attempt(self): - """ - Run a single attempt of the mutate_rows rpc. 
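# Minimal sketch of the retry shape used by start()/_run_attempt above
# (simplified: no gRPC, no timeouts, no retryable-error classification; `attempt`
# is any callable that mutates one entry). Each pass only re-sends the entries
# that failed previously, and per-index errors are accumulated for the final
# exception report. Names here are hypothetical, not taken from the patch.
def mutate_with_retries(entries, attempt, max_attempts=3):
    remaining = list(range(len(entries)))
    errors: dict[int, list[Exception]] = {}
    for _ in range(max_attempts):
        if not remaining:
            break
        still_failing = []
        for idx in remaining:
            try:
                attempt(entries[idx])
            except Exception as exc:  # the real code only re-queues retryable errors
                errors.setdefault(idx, []).append(exc)
                still_failing.append(idx)
        remaining = still_failing
    return errors, remaining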
- - Raises: - - _MutateRowsIncomplete: if there are failed mutations eligible for - retry after the attempt is complete - - GoogleAPICallError: if the gapic rpc fails - """ - request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] - active_request_indices = { - req_idx: orig_idx - for (req_idx, orig_idx) in enumerate(self.remaining_indices) - } - self.remaining_indices = [] - if not request_entries: - return - try: - result_generator = self._gapic_fn( - timeout=next(self.timeout_generator), - entries=request_entries, - retry=None, - ) - for result_list in result_generator: - for result in result_list.entries: - orig_idx = active_request_indices[result.index] - entry_error = core_exceptions.from_grpc_status( - result.status.code, - result.status.message, - details=result.status.details, - ) - if result.status.code != 0: - self._handle_entry_error(orig_idx, entry_error) - elif orig_idx in self.errors: - del self.errors[orig_idx] - del active_request_indices[result.index] - except Exception as exc: - for idx in active_request_indices.values(): - self._handle_entry_error(idx, exc) - raise - if self.remaining_indices: - raise bt_exceptions._MutateRowsIncomplete - - def _handle_entry_error(self, idx: int, exc: Exception): - """ - Add an exception to the list of exceptions for a given mutation index, - and add the index to the list of remaining indices if the exception is - retryable. - - Args: - - idx: the index of the mutation that failed - - exc: the exception to add to the list - """ - entry = self.mutations[idx].entry - self.errors.setdefault(idx, []).append(exc) - if ( - entry.is_idempotent() - and self.is_retryable(exc) - and (idx not in self.remaining_indices) - ): - self.remaining_indices.append(idx) - - -class MutationsBatcher_SyncGen(ABC): - """ - Allows users to send batches using context manager API: - - Runs mutate_row, mutate_rows, and check_and_mutate_row internally, combining - to use as few network requests as required - - Flushes: - - every flush_interval seconds - - after queue reaches flush_count in quantity - - after queue reaches flush_size_bytes in storage size - - when batcher is closed or destroyed - - async with table.mutations_batcher() as batcher: - for i in range(10): - batcher.add(row, mut) - """ - - def __init__( - self, - table: "google.cloud.bigtable.data._sync.client.Table", - *, - flush_interval: float | None = 5, - flush_limit_mutation_count: int | None = 1000, - flush_limit_bytes: int = 20 * _MB_SIZE, - flow_control_max_mutation_count: int = 100000, - flow_control_max_bytes: int = 100 * _MB_SIZE, - batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - batch_retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - ): - """ - Args: - - table: Table to preform rpc calls - - flush_interval: Automatically flush every flush_interval seconds. - If None, no time-based flushing is performed. - - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count - mutations are added across all entries. If None, this limit is ignored. - - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. - - flow_control_max_mutation_count: Maximum number of inflight mutations. - - flow_control_max_bytes: Maximum number of inflight bytes. - - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. 
- If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_operation_timeout. - - batch_attempt_timeout: timeout for each individual request, in seconds. - If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_attempt_timeout. - If None, defaults to batch_operation_timeout. - - batch_retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_mutate_rows_retryable_errors. - """ - (self._operation_timeout, self._attempt_timeout) = _get_timeouts( - batch_operation_timeout, batch_attempt_timeout, table - ) - self._retryable_errors: list[type[Exception]] = _get_retryable_errors( - batch_retryable_errors, table - ) - self._closed: threading.Event = threading.Event() - self._table = table - self._staged_entries: list[RowMutationEntry] = [] - (self._staged_count, self._staged_bytes) = (0, 0) - self._flow_control = ( - google.cloud.bigtable.data._sync.mutations_batcher._FlowControl( - flow_control_max_mutation_count, flow_control_max_bytes - ) - ) - self._flush_limit_bytes = flush_limit_bytes - self._flush_limit_count = ( - flush_limit_mutation_count - if flush_limit_mutation_count is not None - else float("inf") - ) - self._flush_timer = self._create_bg_task(self._timer_routine, flush_interval) - self._flush_jobs: set[concurrent.futures.Future[None]] = set() - self._entries_processed_since_last_raise: int = 0 - self._exceptions_since_last_raise: int = 0 - self._exception_list_limit: int = 10 - self._oldest_exceptions: list[Exception] = [] - self._newest_exceptions: deque[Exception] = deque( - maxlen=self._exception_list_limit - ) - atexit.register(self._on_exit) - - def _timer_routine(self, interval: float | None) -> None: - raise NotImplementedError("Function not implemented in sync class") - - def append(self, mutation_entry: RowMutationEntry): - """ - Add a new set of mutations to the internal queue - - TODO: return a future to track completion of this entry - - Args: - - mutation_entry: new entry to add to flush queue - Raises: - - RuntimeError if batcher is closed - - ValueError if an invalid mutation type is added - """ - if self._closed.is_set(): - raise RuntimeError("Cannot append to closed MutationsBatcher") - if isinstance(mutation_entry, Mutation): - raise ValueError( - f"invalid mutation type: {type(mutation_entry).__name__}. 
Only RowMutationEntry objects are supported by batcher" - ) - self._staged_entries.append(mutation_entry) - self._staged_count += len(mutation_entry.mutations) - self._staged_bytes += mutation_entry.size() - if ( - self._staged_count >= self._flush_limit_count - or self._staged_bytes >= self._flush_limit_bytes - ): - self._schedule_flush() - time.sleep(0) - - def _schedule_flush(self) -> concurrent.futures.Future[None] | None: - """Update the flush task to include the latest staged entries""" - if self._staged_entries: - (entries, self._staged_entries) = (self._staged_entries, []) - (self._staged_count, self._staged_bytes) = (0, 0) - new_task = self._create_bg_task(self._flush_internal, entries) - if not new_task.done(): - self._flush_jobs.add(new_task) - new_task.add_done_callback(self._flush_jobs.remove) - return new_task - return None - - def _flush_internal(self, new_entries: list[RowMutationEntry]): - """ - Flushes a set of mutations to the server, and updates internal state - - Args: - - new_entries: list of RowMutationEntry objects to flush - """ - in_process_requests: list[ - concurrent.futures.Future[list[FailedMutationEntryError]] - ] = [] - for batch in self._flow_control.add_to_flow(new_entries): - batch_task = self._create_bg_task(self._execute_mutate_rows, batch) - in_process_requests.append(batch_task) - found_exceptions = self._wait_for_batch_results(*in_process_requests) - self._entries_processed_since_last_raise += len(new_entries) - self._add_exceptions(found_exceptions) - - def _execute_mutate_rows( - self, batch: list[RowMutationEntry] - ) -> list[FailedMutationEntryError]: - """ - Helper to execute mutation operation on a batch - - Args: - - batch: list of RowMutationEntry objects to send to server - - timeout: timeout in seconds. Used as operation_timeout and attempt_timeout. - If not given, will use table defaults - Returns: - - list of FailedMutationEntryError objects for mutations that failed. - FailedMutationEntryError objects will not contain index information - """ - try: - operation = ( - google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation( - self._table.client._gapic_client, - self._table, - batch, - operation_timeout=self._operation_timeout, - attempt_timeout=self._attempt_timeout, - retryable_exceptions=self._retryable_errors, - ) - ) - operation.start() - except MutationsExceptionGroup as e: - for subexc in e.exceptions: - subexc.index = None - return list(e.exceptions) - finally: - self._flow_control.remove_from_flow(batch) - return [] - - def _add_exceptions(self, excs: list[Exception]): - """ - Add new list of exceptions to internal store. To avoid unbounded memory, - the batcher will store the first and last _exception_list_limit exceptions, - and discard any in between. 
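# Small illustration of the exception-truncation policy described above: keep the
# first and last `limit` exceptions and drop everything in between, so memory
# stays bounded no matter how many entries fail. This demo uses one exception per
# step rather than batches, purely for clarity.
from collections import deque


def demo_truncation(limit=3):
    oldest: list[Exception] = []
    newest: deque[Exception] = deque(maxlen=limit)
    for i in range(10):
        exc = RuntimeError(f"failure {i}")
        if len(oldest) < limit:
            oldest.append(exc)
        else:
            newest.append(exc)
    return [str(e) for e in oldest], [str(e) for e in newest]


print(demo_truncation())
# (['failure 0', 'failure 1', 'failure 2'], ['failure 7', 'failure 8', 'failure 9'])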
- """ - self._exceptions_since_last_raise += len(excs) - if excs and len(self._oldest_exceptions) < self._exception_list_limit: - addition_count = self._exception_list_limit - len(self._oldest_exceptions) - self._oldest_exceptions.extend(excs[:addition_count]) - excs = excs[addition_count:] - if excs: - self._newest_exceptions.extend(excs[-self._exception_list_limit :]) - - def _raise_exceptions(self): - """ - Raise any unreported exceptions from background flush operations - - Raises: - - MutationsExceptionGroup with all unreported exceptions - """ - if self._oldest_exceptions or self._newest_exceptions: - (oldest, self._oldest_exceptions) = (self._oldest_exceptions, []) - newest = list(self._newest_exceptions) - self._newest_exceptions.clear() - (entry_count, self._entries_processed_since_last_raise) = ( - self._entries_processed_since_last_raise, - 0, - ) - (exc_count, self._exceptions_since_last_raise) = ( - self._exceptions_since_last_raise, - 0, - ) - raise MutationsExceptionGroup.from_truncated_lists( - first_list=oldest, - last_list=newest, - total_excs=exc_count, - entry_count=entry_count, - ) - - def __enter__(self): - """For context manager API""" - return self - - def __exit__(self, exc_type, exc, tb): - """For context manager API""" - self.close() - - @property - def closed(self) -> bool: - """ - Returns: - - True if the batcher is closed, False otherwise - """ - return self._closed.is_set() - - def close(self): - raise NotImplementedError("Function not implemented in sync class") - - def _on_exit(self): - """Called when program is exited. Raises warning if unflushed mutations remain""" - if not self._closed.is_set() and self._staged_entries: - warnings.warn( - f"MutationsBatcher for table {self._table.table_name} was not closed. {len(self._staged_entries)} Unflushed mutations will not be sent to the server." - ) - - @staticmethod - def _create_bg_task(func, *args, **kwargs) -> concurrent.futures.Future[Any]: - raise NotImplementedError("Function not implemented in sync class") - - @staticmethod - def _wait_for_batch_results( - *tasks: concurrent.futures.Future[list[FailedMutationEntryError]] - | concurrent.futures.Future[None], - ) -> list[Exception]: - raise NotImplementedError("Function not implemented in sync class") - - -class _FlowControl_SyncGen(ABC): - """ - Manages flow control for batched mutations. Mutations are registered against - the FlowControl object before being sent, which will block if size or count - limits have reached capacity. As mutations completed, they are removed from - the FlowControl object, which will notify any blocked requests that there - is additional capacity. - - Flow limits are not hard limits. If a single mutation exceeds the configured - limits, it will be allowed as a single batch when the capacity is available. - """ - - def __init__(self, max_mutation_count: int, max_mutation_bytes: int): - """ - Args: - - max_mutation_count: maximum number of mutations to send in a single rpc. - This corresponds to individual mutations in a single RowMutationEntry. - - max_mutation_bytes: maximum number of bytes to send in a single rpc. 
- """ - self._max_mutation_count = max_mutation_count - self._max_mutation_bytes = max_mutation_bytes - if self._max_mutation_count < 1: - raise ValueError("max_mutation_count must be greater than 0") - if self._max_mutation_bytes < 1: - raise ValueError("max_mutation_bytes must be greater than 0") - self._capacity_condition = threading.Condition() - self._in_flight_mutation_count = 0 - self._in_flight_mutation_bytes = 0 - - def _has_capacity(self, additional_count: int, additional_size: int) -> bool: - """ - Checks if there is capacity to send a new entry with the given size and count - - FlowControl limits are not hard limits. If a single mutation exceeds - the configured flow limits, it will be sent in a single batch when - previous batches have completed. - - Args: - - additional_count: number of mutations in the pending entry - - additional_size: size of the pending entry - Returns: - - True if there is capacity to send the pending entry, False otherwise - """ - acceptable_size = max(self._max_mutation_bytes, additional_size) - acceptable_count = max(self._max_mutation_count, additional_count) - new_size = self._in_flight_mutation_bytes + additional_size - new_count = self._in_flight_mutation_count + additional_count - return new_size <= acceptable_size and new_count <= acceptable_count - - def remove_from_flow( - self, mutations: RowMutationEntry | list[RowMutationEntry] - ) -> None: - """ - Removes mutations from flow control. This method should be called once - for each mutation that was sent to add_to_flow, after the corresponding - operation is complete. - - Args: - - mutations: mutation or list of mutations to remove from flow control - """ - if not isinstance(mutations, list): - mutations = [mutations] - total_count = sum((len(entry.mutations) for entry in mutations)) - total_size = sum((entry.size() for entry in mutations)) - self._in_flight_mutation_count -= total_count - self._in_flight_mutation_bytes -= total_size - with self._capacity_condition: - self._capacity_condition.notify_all() - - def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): - """ - Generator function that registers mutations with flow control. As mutations - are accepted into the flow control, they are yielded back to the caller, - to be sent in a batch. If the flow control is at capacity, the generator - will block until there is capacity available. - - Args: - - mutations: list mutations to break up into batches - Yields: - - list of mutations that have reserved space in the flow control. - Each batch contains at least one mutation. 
- """ - if not isinstance(mutations, list): - mutations = [mutations] - start_idx = 0 - end_idx = 0 - while end_idx < len(mutations): - start_idx = end_idx - batch_mutation_count = 0 - with self._capacity_condition: - while end_idx < len(mutations): - next_entry = mutations[end_idx] - next_size = next_entry.size() - next_count = len(next_entry.mutations) - if ( - self._has_capacity(next_count, next_size) - and batch_mutation_count + next_count - <= _MUTATE_ROWS_REQUEST_MUTATION_LIMIT - ): - end_idx += 1 - batch_mutation_count += next_count - self._in_flight_mutation_bytes += next_size - self._in_flight_mutation_count += next_count - elif start_idx != end_idx: - break - else: - self._capacity_condition.wait_for( - lambda: self._has_capacity(next_count, next_size) - ) - yield mutations[start_idx:end_idx] +@CrossSync.sync_output( + "google.cloud.bigtable.data._sync._autogen.BigtableDataClient_SyncGen" +) class BigtableDataClient_SyncGen(ClientWithProject, ABC): def __init__( self, @@ -877,7 +70,7 @@ def __init__( Client should be created within an async context (running event loop) - Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not + Warning: BigtableDataClientAsync is currently in preview, and is not yet recommended for production use. Args: @@ -899,7 +92,7 @@ def __init__( - ValueError if pool_size is less than 1 """ transport_str = f"bt-{self._client_version()}-{pool_size}" - transport = PooledBigtableGrpcTransport.with_fixed_size(pool_size) + transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) BigtableClientMeta._transport_registry[transport_str] = transport client_info = DEFAULT_CLIENT_INFO client_info.client_library_version = self._client_version() @@ -920,25 +113,28 @@ def __init__( project=project, client_options=client_options, ) - self._gapic_client = BigtableClient( + self._gapic_client = BigtableAsyncClient( transport=transport_str, credentials=credentials, client_options=client_options, client_info=client_info, ) - self._is_closed = threading.Event() - self.transport = cast(PooledBigtableGrpcTransport, self._gapic_client.transport) + self._is_closed = asyncio.Event() + self.transport = cast( + PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport + ) self._active_instances: Set[_helpers._WarmedInstanceKey] = set() self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() - self._channel_refresh_tasks: list[concurrent.futures.Future[None]] = [] + self._channel_refresh_tasks: list[asyncio.Task[None]] = [] + self._executor = concurrent.futures.ThreadPoolExecutor() if not False else None if self._emulator_host is not None: warnings.warn( "Connecting to Bigtable emulator at {}".format(self._emulator_host), RuntimeWarning, stacklevel=2, ) - self.transport._grpc_channel = PooledChannel( + self.transport._grpc_channel = AsyncPooledChannel( pool_size=pool_size, host=self._emulator_host, insecure=True ) self.transport._stubs = {} @@ -955,10 +151,44 @@ def __init__( @staticmethod def _client_version() -> str: - raise NotImplementedError("Function not implemented in sync class") + """Helper function to return the client version string for this client""" + if False: + return f"{google.cloud.bigtable.__version__}-data-async" + else: + return f"{google.cloud.bigtable.__version__}-data" def _start_background_channel_refresh(self) -> None: - raise NotImplementedError("Function not implemented in sync class") + """ + Starts a background task to ping and 
warm each channel in the pool + Raises: + - RuntimeError if not called in an asyncio event loop + """ + if ( + not self._channel_refresh_tasks + and (not self._emulator_host) + and (not self._is_closed.is_set()) + ): + for channel_idx in range(self.transport.pool_size): + refresh_task = CrossSync.create_task( + self._manage_channel, + channel_idx, + sync_executor=self._executor, + task_name=f"{self.__class__.__name__} channel refresh {channel_idx}", + ) + self._channel_refresh_tasks.append(refresh_task) + refresh_task.add_done_callback( + lambda _: self._channel_refresh_tasks.remove(refresh_task) + ) + + def close(self, timeout: float | None = None): + """Cancel all background tasks""" + self._is_closed.set() + for task in self._channel_refresh_tasks: + task.cancel() + self.transport.close() + if self._executor: + self._executor.shutdown(wait=False) + CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None @@ -995,10 +225,10 @@ def _ping_and_warm_instances( ) for (instance_name, table_name, app_profile_id) in instance_list ] - return self._execute_ping_and_warms(*partial_list) - - def _execute_ping_and_warms(self, *fns) -> list[BaseException | None]: - raise NotImplementedError("Function not implemented in sync class") + result_list = CrossSync.gather_partials( + partial_list, return_exceptions=True, sync_executor=self._executor + ) + return [r or None for r in result_list] def _manage_channel( self, @@ -1026,6 +256,7 @@ def _manage_channel( grace_period: time to allow previous channel to serve existing requests before closing, in seconds """ + sleep_fn = asyncio.sleep if False else self._is_closed.wait first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) @@ -1034,7 +265,7 @@ def _manage_channel( channel = self.transport.channels[channel_idx] self._ping_and_warm_instances(channel) while not self._is_closed.is_set(): - self._is_closed.wait(next_sleep) + sleep_fn(next_sleep) if self._is_closed.is_set(): break new_channel = self.transport.grpc_channel._create_channel() @@ -1049,9 +280,7 @@ def _manage_channel( next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = next_refresh - (time.monotonic() - start_timestamp) - def _register_instance( - self, instance_id: str, owner: google.cloud.bigtable.data._sync.client.Table - ) -> None: + def _register_instance(self, instance_id: str, owner: TableAsync) -> None: """ Registers an instance with the client, and warms the channel pool for the instance @@ -1079,7 +308,7 @@ def _register_instance( self._start_background_channel_refresh() def _remove_instance_registration( - self, instance_id: str, owner: google.cloud.bigtable.data._sync.client.Table + self, instance_id: str, owner: TableAsync ) -> bool: """ Removes an instance from the client's registered instances, to prevent @@ -1108,12 +337,10 @@ def _remove_instance_registration( except KeyError: return False - def get_table( - self, instance_id: str, table_id: str, *args, **kwargs - ) -> google.cloud.bigtable.data._sync.client.Table: + def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: """ Returns a table instance for making data API requests. All arguments are passed - directly to the google.cloud.bigtable.data._sync.client.Table constructor. + directly to the TableAsync constructor. Args: instance_id: The Bigtable instance ID to associate with this client. 
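# Illustration of the sleep_fn pattern used in the generated _manage_channel above:
# in sync mode the loop sleeps with threading.Event.wait(timeout), so a call to
# close() (which sets the event) wakes the refresh loop immediately instead of
# waiting out the full refresh interval. The refresh_loop helper is illustrative,
# not part of the generated client.
import threading
import time


def refresh_loop(is_closed: threading.Event, interval: float = 0.5):
    while not is_closed.is_set():
        is_closed.wait(interval)  # interruptible sleep
        if is_closed.is_set():
            break
        print("refreshing channel")


if __name__ == "__main__":
    closed = threading.Event()
    worker = threading.Thread(target=refresh_loop, args=(closed,))
    worker.start()
    time.sleep(1.2)
    closed.set()  # wakes the loop without waiting a full interval
    worker.join()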
@@ -1145,9 +372,7 @@ def get_table( encountered during all other operations. Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) """ - return google.cloud.bigtable.data._sync.client.Table( - self, instance_id, table_id, *args, **kwargs - ) + return TableAsync(self, instance_id, table_id, *args, **kwargs) def __enter__(self): self._start_background_channel_refresh() @@ -1156,811 +381,3 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() self._gapic_client.__exit__(exc_type, exc_val, exc_tb) - - -class Table_SyncGen(ABC): - """ - Main Data API surface - - Table object maintains table_id, and app_profile_id context, and passes them with - each call - """ - - def __init__( - self, - client: google.cloud.bigtable.data._sync.client.BigtableDataClient, - instance_id: str, - table_id: str, - app_profile_id: str | None = None, - *, - default_read_rows_operation_timeout: float = 600, - default_read_rows_attempt_timeout: float | None = 20, - default_mutate_rows_operation_timeout: float = 600, - default_mutate_rows_attempt_timeout: float | None = 60, - default_operation_timeout: float = 60, - default_attempt_timeout: float | None = 20, - default_read_rows_retryable_errors: Sequence[type[Exception]] = ( - DeadlineExceeded, - ServiceUnavailable, - Aborted, - ), - default_mutate_rows_retryable_errors: Sequence[type[Exception]] = ( - DeadlineExceeded, - ServiceUnavailable, - ), - default_retryable_errors: Sequence[type[Exception]] = ( - DeadlineExceeded, - ServiceUnavailable, - ), - ): - """ - Initialize a Table instance - - Must be created within an async context (running event loop) - - Args: - instance_id: The Bigtable instance ID to associate with this client. - instance_id is combined with the client's project to fully - specify the instance - table_id: The ID of the table. table_id is combined with the - instance_id and the client's project to fully specify the table - app_profile_id: The app profile to associate with requests. - https://cloud.google.com/bigtable/docs/app-profiles - default_read_rows_operation_timeout: The default timeout for read rows - operations, in seconds. If not set, defaults to 600 seconds (10 minutes) - default_read_rows_attempt_timeout: The default timeout for individual - read rows rpc requests, in seconds. If not set, defaults to 20 seconds - default_mutate_rows_operation_timeout: The default timeout for mutate rows - operations, in seconds. If not set, defaults to 600 seconds (10 minutes) - default_mutate_rows_attempt_timeout: The default timeout for individual - mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds - default_operation_timeout: The default timeout for all other operations, in - seconds. If not set, defaults to 60 seconds - default_attempt_timeout: The default timeout for all other individual rpc - requests, in seconds. If not set, defaults to 20 seconds - default_read_rows_retryable_errors: a list of errors that will be retried - if encountered during read_rows and related operations. - Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) - default_mutate_rows_retryable_errors: a list of errors that will be retried - if encountered during mutate_rows and related operations. - Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) - default_retryable_errors: a list of errors that will be retried if - encountered during all other operations. 
- Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) - Raises: - - RuntimeError if called outside of an async context (no running event loop) - """ - _helpers._validate_timeouts( - default_operation_timeout, default_attempt_timeout, allow_none=True - ) - _helpers._validate_timeouts( - default_read_rows_operation_timeout, - default_read_rows_attempt_timeout, - allow_none=True, - ) - _helpers._validate_timeouts( - default_mutate_rows_operation_timeout, - default_mutate_rows_attempt_timeout, - allow_none=True, - ) - self.client = client - self.instance_id = instance_id - self.instance_name = self.client._gapic_client.instance_path( - self.client.project, instance_id - ) - self.table_id = table_id - self.table_name = self.client._gapic_client.table_path( - self.client.project, instance_id, table_id - ) - self.app_profile_id = app_profile_id - self.default_operation_timeout = default_operation_timeout - self.default_attempt_timeout = default_attempt_timeout - self.default_read_rows_operation_timeout = default_read_rows_operation_timeout - self.default_read_rows_attempt_timeout = default_read_rows_attempt_timeout - self.default_mutate_rows_operation_timeout = ( - default_mutate_rows_operation_timeout - ) - self.default_mutate_rows_attempt_timeout = default_mutate_rows_attempt_timeout - self.default_read_rows_retryable_errors = ( - default_read_rows_retryable_errors or () - ) - self.default_mutate_rows_retryable_errors = ( - default_mutate_rows_retryable_errors or () - ) - self.default_retryable_errors = default_retryable_errors or () - self._register_instance_future: concurrent.futures.Future[ - None - ] = self._register_with_client() - - def _register_with_client(self) -> concurrent.futures.Future[None]: - raise NotImplementedError("Function not implemented in sync class") - - def read_rows_stream( - self, - query: ReadRowsQuery, - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - ) -> Iterable[Row]: - """ - Read a set of rows from the table, based on the specified query. - Returns an iterator to asynchronously stream back row data. - - Failed requests within operation_timeout will be retried based on the - retryable_errors list until operation_timeout is reached. - - Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not - yet recommended for production use. - - Args: - - query: contains details about which rows to return - - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_read_rows_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_read_rows_attempt_timeout. - If None, defaults to operation_timeout. - - retryable_errors: a list of errors that will be retried if encountered. 
- Defaults to the Table's default_read_rows_retryable_errors - Returns: - - an asynchronous iterator that yields rows returned by the query - Raises: - - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions - from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error - """ - (operation_timeout, attempt_timeout) = _helpers._get_timeouts( - operation_timeout, attempt_timeout, self - ) - retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - row_merger = google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation( - query, - self, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_exceptions=retryable_excs, - ) - return row_merger.start_operation() - - def read_rows( - self, - query: ReadRowsQuery, - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - ) -> list[Row]: - """ - Read a set of rows from the table, based on the specified query. - Retruns results as a list of Row objects when the request is complete. - For streamed results, use read_rows_stream. - - Failed requests within operation_timeout will be retried based on the - retryable_errors list until operation_timeout is reached. - - Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not - yet recommended for production use. - - Args: - - query: contains details about which rows to return - - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_read_rows_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_read_rows_attempt_timeout. - If None, defaults to operation_timeout. - If None, defaults to the Table's default_read_rows_attempt_timeout, - or the operation_timeout if that is also None. - - retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_read_rows_retryable_errors. - Returns: - - a list of Rows returned by the query - Raises: - - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions - from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error - """ - row_generator = self.read_rows_stream( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, - ) - return [row for row in row_generator] - - def read_row( - self, - row_key: str | bytes, - *, - row_filter: RowFilter | None = None, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - ) -> Row | None: - """ - Read a single row from the table, based on the specified key. - - Failed requests within operation_timeout will be retried based on the - retryable_errors list until operation_timeout is reached. 
- - Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not - yet recommended for production use. - - Args: - - query: contains details about which rows to return - - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_read_rows_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_read_rows_attempt_timeout. - If None, defaults to operation_timeout. - - retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_read_rows_retryable_errors. - Returns: - - a Row object if the row exists, otherwise None - Raises: - - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions - from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error - """ - if row_key is None: - raise ValueError("row_key must be string or bytes") - query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1) - results = self.read_rows( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, - ) - if len(results) == 0: - return None - return results[0] - - def read_rows_sharded( - self, - sharded_query: ShardedQuery, - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - ) -> list[Row]: - """ - Runs a sharded query in parallel, then return the results in a single list. - Results will be returned in the order of the input queries. - - This function is intended to be run on the results on a query.shard() call: - - ``` - table_shard_keys = await table.sample_row_keys() - query = ReadRowsQuery(...) - shard_queries = query.shard(table_shard_keys) - results = await table.read_rows_sharded(shard_queries) - ``` - - Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not - yet recommended for production use. - - Args: - - sharded_query: a sharded query to execute - - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_read_rows_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_read_rows_attempt_timeout. - If None, defaults to operation_timeout. - - retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_read_rows_retryable_errors. 
- Raises: - - ShardedReadRowsExceptionGroup: if any of the queries failed - - ValueError: if the query_list is empty - """ - if not sharded_query: - raise ValueError("empty sharded_query") - (operation_timeout, attempt_timeout) = _helpers._get_timeouts( - operation_timeout, attempt_timeout, self - ) - timeout_generator = _helpers._attempt_timeout_generator( - operation_timeout, operation_timeout - ) - batched_queries = [ - sharded_query[i : i + _helpers._CONCURRENCY_LIMIT] - for i in range(0, len(sharded_query), _helpers._CONCURRENCY_LIMIT) - ] - results_list = [] - error_dict = {} - shard_idx = 0 - for batch in batched_queries: - batch_operation_timeout = next(timeout_generator) - batch_kwargs_list = [ - { - "query": query, - "operation_timeout": batch_operation_timeout, - "attempt_timeout": min(attempt_timeout, batch_operation_timeout), - "retryable_errors": retryable_errors, - } - for query in batch - ] - batch_result = self._shard_batch_helper(batch_kwargs_list) - for result in batch_result: - if isinstance(result, Exception): - error_dict[shard_idx] = result - elif isinstance(result, BaseException): - raise result - else: - results_list.extend(result) - shard_idx += 1 - if error_dict: - raise ShardedReadRowsExceptionGroup( - [ - FailedQueryShardError(idx, sharded_query[idx], e) - for (idx, e) in error_dict.items() - ], - results_list, - len(sharded_query), - ) - return results_list - - def _shard_batch_helper( - self, kwargs_list: list[dict] - ) -> list[list[Row] | BaseException]: - raise NotImplementedError("Function not implemented in sync class") - - def row_exists( - self, - row_key: str | bytes, - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - ) -> bool: - """ - Return a boolean indicating whether the specified row exists in the table. - uses the filters: chain(limit cells per row = 1, strip value) - - Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not - yet recommended for production use. - - Args: - - row_key: the key of the row to check - - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_read_rows_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_read_rows_attempt_timeout. - If None, defaults to operation_timeout. - - retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_read_rows_retryable_errors. 
-        Returns:
-            - a bool indicating whether the row exists
-        Raises:
-            - DeadlineExceeded: raised after operation timeout
-                will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
-                from any retries that failed
-            - GoogleAPIError: raised if the request encounters an unrecoverable error
-        """
-        if row_key is None:
-            raise ValueError("row_key must be string or bytes")
-        strip_filter = StripValueTransformerFilter(flag=True)
-        limit_filter = CellsRowLimitFilter(1)
-        chain_filter = RowFilterChain(filters=[limit_filter, strip_filter])
-        query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter)
-        results = self.read_rows(
-            query,
-            operation_timeout=operation_timeout,
-            attempt_timeout=attempt_timeout,
-            retryable_errors=retryable_errors,
-        )
-        return len(results) > 0
-
-    def sample_row_keys(
-        self,
-        *,
-        operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
-        attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
-        retryable_errors: Sequence[type[Exception]]
-        | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT,
-    ) -> RowKeySamples:
-        """
-        Return a set of RowKeySamples that delimit contiguous sections of the table of
-        approximately equal size
-
-        RowKeySamples output can be used with ReadRowsQuery.shard() to create a sharded query that
-        can be parallelized across multiple backend nodes. read_rows and read_rows_stream
-        requests will call sample_row_keys internally for this purpose when sharding is enabled
-
-        RowKeySamples is simply a type alias for list[tuple[bytes, int]]; a list of
-        row_keys, along with offset positions in the table
-
-        Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not
-        yet recommended for production use.
-
-        Args:
-            - operation_timeout: the time budget for the entire operation, in seconds.
-                Failed requests will be retried within the budget.
-                Defaults to the Table's default_operation_timeout
-            - attempt_timeout: the time budget for an individual network request, in seconds.
-                If it takes longer than this time to complete, the request will be cancelled with
-                a DeadlineExceeded exception, and a retry will be attempted.
-                Defaults to the Table's default_attempt_timeout.
-                If None, defaults to operation_timeout.
-            - retryable_errors: a list of errors that will be retried if encountered.
-                Defaults to the Table's default_retryable_errors.
-        Returns:
-            - a set of RowKeySamples that delimit contiguous sections of the table
-        Raises:
-            - DeadlineExceeded: raised after operation timeout
-                will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
-                from any retries that failed
-            - GoogleAPIError: raised if the request encounters an unrecoverable error
-        """
-        (operation_timeout, attempt_timeout) = _helpers._get_timeouts(
-            operation_timeout, attempt_timeout, self
-        )
-        attempt_timeout_gen = _helpers._attempt_timeout_generator(
-            attempt_timeout, operation_timeout
-        )
-        retryable_excs = _helpers._get_retryable_errors(retryable_errors, self)
-        predicate = retries.if_exception_type(*retryable_excs)
-        sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60)
-        metadata = _helpers._make_metadata(self.table_name, self.app_profile_id)
-
-        def execute_rpc():
-            results = self.client._gapic_client.sample_row_keys(
-                table_name=self.table_name,
-                app_profile_id=self.app_profile_id,
-                timeout=next(attempt_timeout_gen),
-                metadata=metadata,
-                retry=None,
-            )
-            return [(s.row_key, s.offset_bytes) for s in results]
-
-        return retries.retry_target(
-            execute_rpc,
-            predicate,
-            sleep_generator,
-            operation_timeout,
-            exception_factory=_helpers._retry_exception_factory,
-        )
-
-    def mutations_batcher(
-        self,
-        *,
-        flush_interval: float | None = 5,
-        flush_limit_mutation_count: int | None = 1000,
-        flush_limit_bytes: int = 20 * _MB_SIZE,
-        flow_control_max_mutation_count: int = 100000,
-        flow_control_max_bytes: int = 100 * _MB_SIZE,
-        batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
-        batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
-        batch_retryable_errors: Sequence[type[Exception]]
-        | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
-    ) -> google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher:
-        """
-        Returns a new mutations batcher instance.
-
-        Can be used to iteratively add mutations that are flushed as a group,
-        to avoid excess network calls
-
-        Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not
-        yet recommended for production use.
-
-        Args:
-            - flush_interval: Automatically flush every flush_interval seconds. If None,
-                a table default will be used
-            - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count
-                mutations are added across all entries. If None, this limit is ignored.
-            - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added.
-            - flow_control_max_mutation_count: Maximum number of inflight mutations.
-            - flow_control_max_bytes: Maximum number of inflight bytes.
-            - batch_operation_timeout: timeout for each mutate_rows operation, in seconds.
-                Defaults to the Table's default_mutate_rows_operation_timeout
-            - batch_attempt_timeout: timeout for each individual request, in seconds.
-                Defaults to the Table's default_mutate_rows_attempt_timeout.
-                If None, defaults to batch_operation_timeout.
-            - batch_retryable_errors: a list of errors that will be retried if encountered.
-                Defaults to the Table's default_mutate_rows_retryable_errors.
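For illustration, a minimal usage sketch of the batcher surface described above (a sketch only, assuming the generated sync batcher mirrors the async API; `SetCell` and `RowMutationEntry` come from `google.cloud.bigtable.data.mutations`, while the table, family, and row key names are hypothetical):

```
from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell

# entries are buffered and flushed as a group when a flush limit or the
# flush interval is reached; exiting the context manager flushes the rest
with table.mutations_batcher(flush_interval=5) as batcher:
    cell = SetCell(family="stats", qualifier=b"views", new_value=1)
    batcher.append(RowMutationEntry(row_key=b"row-1", mutations=[cell]))
```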
- Returns: - - a google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher context manager that can batch requests - """ - return google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher( - self, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_mutation_count, - flush_limit_bytes=flush_limit_bytes, - flow_control_max_mutation_count=flow_control_max_mutation_count, - flow_control_max_bytes=flow_control_max_bytes, - batch_operation_timeout=batch_operation_timeout, - batch_attempt_timeout=batch_attempt_timeout, - batch_retryable_errors=batch_retryable_errors, - ) - - def mutate_row( - self, - row_key: str | bytes, - mutations: list[Mutation] | Mutation, - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - ): - """ - Mutates a row atomically. - - Cells already present in the row are left unchanged unless explicitly changed - by ``mutation``. - - Idempotent operations (i.e, all mutations have an explicit timestamp) will be - retried on server failure. Non-idempotent operations will not. - - Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not - yet recommended for production use. - - Args: - - row_key: the row to apply mutations to - - mutations: the set of mutations to apply to the row - - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_attempt_timeout. - If None, defaults to operation_timeout. - - retryable_errors: a list of errors that will be retried if encountered. - Only idempotent mutations will be retried. Defaults to the Table's - default_retryable_errors. - Raises: - - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing all - GoogleAPIError exceptions from any retries that failed - - GoogleAPIError: raised on non-idempotent operations that cannot be - safely retried. 
-            - ValueError if invalid arguments are provided
-        """
-        (operation_timeout, attempt_timeout) = _helpers._get_timeouts(
-            operation_timeout, attempt_timeout, self
-        )
-        if not mutations:
-            raise ValueError("No mutations provided")
-        mutations_list = mutations if isinstance(mutations, list) else [mutations]
-        if all((mutation.is_idempotent() for mutation in mutations_list)):
-            predicate = retries.if_exception_type(
-                *_helpers._get_retryable_errors(retryable_errors, self)
-            )
-        else:
-            predicate = retries.if_exception_type()
-        sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60)
-        target = partial(
-            self.client._gapic_client.mutate_row,
-            row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
-            mutations=[mutation._to_pb() for mutation in mutations_list],
-            table_name=self.table_name,
-            app_profile_id=self.app_profile_id,
-            timeout=attempt_timeout,
-            metadata=_helpers._make_metadata(self.table_name, self.app_profile_id),
-            retry=None,
-        )
-        return retries.retry_target(
-            target,
-            predicate,
-            sleep_generator,
-            operation_timeout,
-            exception_factory=_helpers._retry_exception_factory,
-        )
-
-    def bulk_mutate_rows(
-        self,
-        mutation_entries: list[RowMutationEntry],
-        *,
-        operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
-        attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
-        retryable_errors: Sequence[type[Exception]]
-        | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
-    ):
-        """
-        Applies mutations for multiple rows in a single batched request.
-
-        Each individual RowMutationEntry is applied atomically, but separate entries
-        may be applied in arbitrary order (even for entries targeting the same row).
-        In total, the row_mutations can contain at most 100000 individual mutations
-        across all entries
-
-        Idempotent entries (i.e., entries with mutations with explicit timestamps)
-        will be retried on failure. Non-idempotent entries will not, and will be reported in a
-        raised exception group
-
-        Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not
-        yet recommended for production use.
-
-        Args:
-            - mutation_entries: the batches of mutations to apply
-                Each entry will be applied atomically, but entries will be applied
-                in arbitrary order
-            - operation_timeout: the time budget for the entire operation, in seconds.
-                Failed requests will be retried within the budget.
-                Defaults to the Table's default_mutate_rows_operation_timeout
-            - attempt_timeout: the time budget for an individual network request, in seconds.
-                If it takes longer than this time to complete, the request will be cancelled with
-                a DeadlineExceeded exception, and a retry will be attempted.
-                Defaults to the Table's default_mutate_rows_attempt_timeout.
-                If None, defaults to operation_timeout.
-            - retryable_errors: a list of errors that will be retried if encountered.
- Defaults to the Table's default_mutate_rows_retryable_errors - Raises: - - MutationsExceptionGroup if one or more mutations fails - Contains details about any failed entries in .exceptions - - ValueError if invalid arguments are provided - """ - (operation_timeout, attempt_timeout) = _helpers._get_timeouts( - operation_timeout, attempt_timeout, self - ) - retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - operation = google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation( - self.client._gapic_client, - self, - mutation_entries, - operation_timeout, - attempt_timeout, - retryable_exceptions=retryable_excs, - ) - operation.start() - - def check_and_mutate_row( - self, - row_key: str | bytes, - predicate: RowFilter | None, - *, - true_case_mutations: Mutation | list[Mutation] | None = None, - false_case_mutations: Mutation | list[Mutation] | None = None, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - ) -> bool: - """ - Mutates a row atomically based on the output of a predicate filter - - Non-idempotent operation: will not be retried - - Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not - yet recommended for production use. - - Args: - - row_key: the key of the row to mutate - - predicate: the filter to be applied to the contents of the specified row. - Depending on whether or not any results are yielded, - either true_case_mutations or false_case_mutations will be executed. - If None, checks that the row contains any values at all. - - true_case_mutations: - Changes to be atomically applied to the specified row if - predicate yields at least one cell when - applied to row_key. Entries are applied in order, - meaning that earlier mutations can be masked by later - ones. Must contain at least one entry if - false_case_mutations is empty, and at most 100000. - - false_case_mutations: - Changes to be atomically applied to the specified row if - predicate_filter does not yield any cells when - applied to row_key. Entries are applied in order, - meaning that earlier mutations can be masked by later - ones. Must contain at least one entry if - `true_case_mutations is empty, and at most 100000. - - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will not be retried. 
Defaults to the Table's default_operation_timeout - Returns: - - bool indicating whether the predicate was true or false - Raises: - - GoogleAPIError exceptions from grpc call - """ - (operation_timeout, _) = _helpers._get_timeouts(operation_timeout, None, self) - if true_case_mutations is not None and ( - not isinstance(true_case_mutations, list) - ): - true_case_mutations = [true_case_mutations] - true_case_list = [m._to_pb() for m in true_case_mutations or []] - if false_case_mutations is not None and ( - not isinstance(false_case_mutations, list) - ): - false_case_mutations = [false_case_mutations] - false_case_list = [m._to_pb() for m in false_case_mutations or []] - metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) - result = self.client._gapic_client.check_and_mutate_row( - true_mutations=true_case_list, - false_mutations=false_case_list, - predicate_filter=predicate._to_pb() if predicate is not None else None, - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, - ) - return result.predicate_matched - - def read_modify_write_row( - self, - row_key: str | bytes, - rules: ReadModifyWriteRule | list[ReadModifyWriteRule], - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - ) -> Row: - """ - Reads and modifies a row atomically according to input ReadModifyWriteRules, - and returns the contents of all modified cells - - The new value for the timestamp is the greater of the existing timestamp or - the current server time. - - Non-idempotent operation: will not be retried - - Warning: google.cloud.bigtable.data._sync.client.BigtableDataClient is currently in preview, and is not - yet recommended for production use. - - Args: - - row_key: the key of the row to apply read/modify/write rules to - - rules: A rule or set of rules to apply to the row. - Rules are applied in order, meaning that earlier rules will affect the - results of later ones. - - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will not be retried. - Defaults to the Table's default_operation_timeout. 
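As a brief illustration of the `rules` argument (a sketch, assuming a `table` instance; `IncrementRule` is the helper from `google.cloud.bigtable.data.read_modify_write_rules`, and the family, qualifier, and row key shown are hypothetical):

```
from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule

# atomically add 1 to the cell and return the modified row contents
rule = IncrementRule(family="stats", qualifier=b"views", increment_amount=1)
row = table.read_modify_write_row(b"row-1", rule)
```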
- Returns: - - Row: containing cell data that was modified as part of the - operation - Raises: - - GoogleAPIError exceptions from grpc call - - ValueError if invalid arguments are provided - """ - (operation_timeout, _) = _helpers._get_timeouts(operation_timeout, None, self) - if operation_timeout <= 0: - raise ValueError("operation_timeout must be greater than 0") - if rules is not None and (not isinstance(rules, list)): - rules = [rules] - if not rules: - raise ValueError("rules must contain at least one item") - metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) - result = self.client._gapic_client.read_modify_write_row( - rules=[rule._to_pb() for rule in rules], - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, - ) - return Row._from_pb(result.row) - - def close(self): - """Called to close the Table instance and release any resources held by it.""" - if self._register_instance_future: - self._register_instance_future.cancel() - self.client._remove_instance_registration(self.instance_id, self) - - def __enter__(self): - raise NotImplementedError("Function not implemented in sync class") - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Implement async context manager protocol - - Unregister this instance with the client, so that - grpc channels will no longer be warmed - """ - self.close() diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index dbe8f1c1c..d648d601a 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -17,7 +17,13 @@ class CrossSync: @classmethod def sync_output(cls, sync_path): # return the async class unchanged - return lambda async_cls: async_cls + def decorator(async_cls): + async_cls.cross_sync_enabled = True + async_cls.cross_sync_import_path = sync_path + async_cls.cross_sync_class_name = sync_path.rsplit('.', 1)[-1] + async_cls.cross_sync_file_path = "/".join(sync_path.split(".")[:-1]) + ".py" + return async_cls + return decorator @staticmethod async def gather_partials(partial_list, return_exceptions=False, sync_executor=None): diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 3d96d033b..aae969b30 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -21,6 +21,7 @@ import importlib import yaml from pathlib import Path +import os from black import format_str, FileMode import autoflake @@ -36,11 +37,11 @@ class AsyncToSyncTransformer(ast.NodeTransformer): outside of this autogeneration system are always applied """ - def __init__(self, *, name=None, asyncio_replacements=None, text_replacements=None, drop_methods=None, pass_methods=None, error_methods=None, replace_methods=None): + def __init__(self, *, name=None, cross_sync_replacements=None, text_replacements=None, drop_methods=None, pass_methods=None, error_methods=None, replace_methods=None): """ Args: - name: the name of the class being processed. 
Just used in exceptions - - asyncio_replacements: asyncio functionality to replace + - cross_sync_replacements: CrossSync functionality to replace - text_replacements: dict of text to replace directly in the source code and docstrings - drop_methods: list of method names to drop from the class - pass_methods: list of method names to replace with "pass" in the class @@ -48,7 +49,15 @@ def __init__(self, *, name=None, asyncio_replacements=None, text_replacements=No - replace_methods: dict of method names to replace with custom code """ self.name = name - self.asyncio_replacements = asyncio_replacements or {} + self.cross_sync_replacements = cross_sync_replacements or { + "sleep": "time.sleep", + "Queue": "queue.Queue", + "Condition": "threading.Condition", + "Future": "concurrent.futures.Future", + "Task": "concurrent.futures.Future", + "Event": "threading.Event", + "is_async": "False", + } self.text_replacements = text_replacements or {} self.drop_methods = drop_methods or [] self.pass_methods = pass_methods or [] @@ -100,19 +109,20 @@ def visit_AsyncFunctionDef(self, node): if len(parsed.body) > 0: new_body.append(parsed.body[0]) node.body = new_body - else: + # else: # check if the function contains non-replaced usage of asyncio - func_ast = ast.parse(ast.unparse(node)) - for n in ast.walk(func_ast): - if isinstance(n, ast.Call) \ - and isinstance(n.func, ast.Attribute) \ - and isinstance(n.func.value, ast.Name) \ - and n.func.value.id == "asyncio" \ - and n.func.attr not in self.asyncio_replacements: - path_str = f"{self.name}.{node.name}" if self.name else node.name - raise RuntimeError( - f"{path_str} contains unhandled asyncio calls: {n.func.attr}. Add method to drop_methods, pass_methods, or error_methods to handle safely." - ) + # func_ast = ast.parse(ast.unparse(node)) + # for n in ast.walk(func_ast): + # if isinstance(n, ast.Call) \ + # and isinstance(n.func, ast.Attribute) \ + # and isinstance(n.func.value, ast.Name) \ + # and n.func.value.id == "CrossSync" \ + # and n.func.attr not in self.cross_sync_replacements: + # path_str = f"{self.name}.{node.name}" if self.name else node.name + # breakpoint() + # raise RuntimeError( + # f"{path_str} contains unhandled asyncio calls: {n.func.attr}. Add method to drop_methods, pass_methods, or error_methods to handle safely." 
+ # ) # remove pytest.mark.asyncio decorator if hasattr(node, "decorator_list"): # TODO: make generic @@ -153,10 +163,10 @@ def visit_Attribute(self, node): if ( isinstance(node.value, ast.Name) and isinstance(node.value.ctx, ast.Load) - and node.value.id == "asyncio" - and node.attr in self.asyncio_replacements + and node.value.id == "CrossSync" + and node.attr in self.cross_sync_replacements ): - replacement = self.asyncio_replacements[node.attr] + replacement = self.cross_sync_replacements[node.attr] return ast.copy_location(ast.parse(replacement, mode="eval").body, node) fixed = ast.copy_location( ast.Attribute( @@ -248,7 +258,7 @@ def _create_error_node(node, error_msg): def get_imports(self, filename): """ - Get the imports from a file, and do a find-and-replace against asyncio_replacements + Get the imports from a file, and do a find-and-replace against cross_sync_replacements """ imports = set() with open(filename, "r") as f: @@ -258,14 +268,14 @@ def get_imports(self, filename): for alias in node.names: if isinstance(node, ast.Import): # import statments - new_import = self.asyncio_replacements.get(alias.name, alias.name) + new_import = self.cross_sync_replacements.get(alias.name, alias.name) imports.add(ast.parse(f"import {new_import}").body[0]) else: # import from statements # break into individual components full_path = f"{node.module}.{alias.name}" - if full_path in self.asyncio_replacements: - full_path = self.asyncio_replacements[full_path] + if full_path in self.cross_sync_replacements: + full_path = self.cross_sync_replacements[full_path] module, name = full_path.rsplit(".", 1) # don't import from same file if module == ".": @@ -378,14 +388,81 @@ def transform_from_config(config_dict: dict): if __name__ == "__main__": - for load_path in ["./google/cloud/bigtable/data/_sync/sync_gen.yaml", "./google/cloud/bigtable/data/_sync/unit_tests.yaml"]: - config = yaml.safe_load(Path(load_path).read_text()) + # for load_path in ["./google/cloud/bigtable/data/_sync/sync_gen.yaml", "./google/cloud/bigtable/data/_sync/unit_tests.yaml"]: + # for load_path in ["./google/cloud/bigtable/data/_sync/sync_gen.yaml"]: + # config = yaml.safe_load(Path(load_path).read_text()) - save_path = config.get("save_path") - code = transform_from_config(config) + # save_path = config.get("save_path") + # code = transform_from_config(config) - if save_path is not None: - with open(save_path, "w") as f: - f.write(code) + # if save_path is not None: + # with open(save_path, "w") as f: + # f.write(code) + # find all classes in the library + import google.cloud.bigtable.data as data_lib + lib_classes = inspect.getmembers(data_lib, inspect.isclass) + # keep only those with CrossSync annotation + enabled_classes = [c[1] for c in lib_classes if hasattr(c[1], "cross_sync_enabled")] + # bucket classes by output location + all_paths = {c.cross_sync_file_path for c in enabled_classes} + class_map = {loc: [c for c in enabled_classes if c.cross_sync_file_path == loc] for loc in all_paths} + # generate sync code for each class + for output_file in class_map.keys(): + # initialize new tree and import list + combined_tree = ast.parse("") + combined_imports = set() + for async_class in class_map[output_file]: + class_dict = { + "text_replacements": { + "__anext__": "__next__", + "__aenter__": "__enter__", + "__aexit__": "__exit__", + "__aiter__": "__iter__", + "aclose": "close", + "AsyncIterable": "Iterable", + "AsyncIterator": "Iterator", + "StopAsyncIteration": "StopIteration", + "Awaitable": None, + "CrossSync.Event": 
"threading.Event", + }, + "autogen_sync_name": async_class.cross_sync_class_name, + } + tree_body, imports = transform_class(async_class, **class_dict) + # update combined data + combined_tree.body.extend(tree_body) + combined_imports.update(imports) + # render tree as string of code + import_unique = list(set([ast.unparse(i) for i in combined_imports])) + import_unique.sort() + google, non_google = [], [] + for i in import_unique: + if "google" in i: + google.append(i) + else: + non_google.append(i) + import_str = "\n".join(non_google + [""] + google) + # append clean tree + header = """# Copyright 2024 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # + # This file is automatically generated by sync_surface_generator.py. Do not edit. + """ + full_code = f"{header}\n\n{import_str}\n\n{ast.unparse(combined_tree)}" + full_code = autoflake.fix_code(full_code, remove_all_unused_imports=True) + formatted_code = format_str(full_code, mode=FileMode()) + print(f"saving {async_class.cross_sync_class_name} to {output_file}...") + with open(output_file, "w") as f: + f.write(formatted_code) From 4cbeaba89b8cc736c96e50e0df970ce2f9f45988 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 19 Apr 2024 12:41:33 -0700 Subject: [PATCH 058/360] added sync implementations of certain functions --- google/cloud/bigtable/data/__init__.py | 12 +- google/cloud/bigtable/data/_async/client.py | 7 +- .../bigtable/data/_async/mutations_batcher.py | 4 + google/cloud/bigtable/data/_sync/_autogen.py | 383 ----- .../cloud/bigtable/data/_sync/_mutate_rows.py | 22 - .../cloud/bigtable/data/_sync/_read_rows.py | 22 - google/cloud/bigtable/data/_sync/client.py | 1233 ++++++++++++++++- .../cloud/bigtable/data/_sync/cross_sync.py | 56 + .../bigtable/data/_sync/mutations_batcher.py | 360 ++++- sync_surface_generator.py | 9 +- 10 files changed, 1557 insertions(+), 551 deletions(-) delete mode 100644 google/cloud/bigtable/data/_sync/_autogen.py delete mode 100644 google/cloud/bigtable/data/_sync/_mutate_rows.py delete mode 100644 google/cloud/bigtable/data/_sync/_read_rows.py diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index fd44fe86c..cdb7622b6 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -20,10 +20,10 @@ from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync -from google.cloud.bigtable.data._sync.client import BigtableDataClient -from google.cloud.bigtable.data._sync.client import Table +# from google.cloud.bigtable.data._sync.client import BigtableDataClient +# from google.cloud.bigtable.data._sync.client import Table -from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher +# from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.read_rows_query import RowRange @@ -53,9 +53,9 @@ __version__: str = 
package_version.__version__ __all__ = ( - "BigtableDataClient", - "Table", - "MutationsBatcher", + # "BigtableDataClient", + # "Table", + # "MutationsBatcher", "BigtableDataClientAsync", "TableAsync", "MutationsBatcherAsync", diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 02e0ce05e..752fa7047 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -82,12 +82,8 @@ from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery -if CrossSync.SyncImports: - from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport, PooledChannel - -@CrossSync.sync_output("google.cloud.bigtable.data._sync._autogen.BigtableDataClient_SyncGen") +@CrossSync.sync_output("google.cloud.bigtable.data._sync.client.BigtableDataClient") class BigtableDataClientAsync(ClientWithProject): def __init__( @@ -441,6 +437,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) +@CrossSync.sync_output("google.cloud.bigtable.data._sync.client.Table") class TableAsync: """ Main Data API surface diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 7faf10f24..a8f229083 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -32,6 +32,8 @@ ) from google.cloud.bigtable.data.mutations import Mutation +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + if TYPE_CHECKING: from google.cloud.bigtable.data._async.client import TableAsync from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -40,6 +42,7 @@ _MB_SIZE = 1024 * 1024 +@CrossSync.sync_output("google.cloud.bigtable.data._sync.mutations_batcher._FlowControl") class _FlowControlAsync: """ Manages flow control for batched mutations. Mutations are registered against @@ -164,6 +167,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] yield mutations[start_idx:end_idx] +@CrossSync.sync_output("google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher") class MutationsBatcherAsync: """ Allows users to send batches using context manager API: diff --git a/google/cloud/bigtable/data/_sync/_autogen.py b/google/cloud/bigtable/data/_sync/_autogen.py deleted file mode 100644 index e537b3ac6..000000000 --- a/google/cloud/bigtable/data/_sync/_autogen.py +++ /dev/null @@ -1,383 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is automatically generated by sync_surface_generator.py. Do not edit. 
- - -from __future__ import annotations -from abc import ABC -from functools import partial -from grpc import Channel -from typing import Any -from typing import Optional -from typing import Set -from typing import cast -import asyncio -import concurrent.futures -import os -import random -import time -import warnings - -from google.api_core import client_options as client_options_lib -from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT -from google.cloud.bigtable.data import _helpers -from google.cloud.bigtable.data._async.client import TableAsync -from google.cloud.bigtable.data._sync.cross_sync import CrossSync -from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient -from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO -from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, -) -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel as AsyncPooledChannel, -) -from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest -from google.cloud.client import ClientWithProject -from google.cloud.environment_vars import BIGTABLE_EMULATOR -import google.auth._default -import google.auth.credentials - - -@CrossSync.sync_output( - "google.cloud.bigtable.data._sync._autogen.BigtableDataClient_SyncGen" -) -class BigtableDataClient_SyncGen(ClientWithProject, ABC): - def __init__( - self, - *, - project: str | None = None, - pool_size: int = 3, - credentials: google.auth.credentials.Credentials | None = None, - client_options: dict[str, Any] - | "google.api_core.client_options.ClientOptions" - | None = None, - ): - """ - Create a client instance for the Bigtable Data API - - Client should be created within an async context (running event loop) - - Warning: BigtableDataClientAsync is currently in preview, and is not - yet recommended for production use. - - Args: - project: the project which the client acts on behalf of. - If not passed, falls back to the default inferred - from the environment. - pool_size: The number of grpc channels to maintain - in the internal channel pool. - credentials: - Thehe OAuth2 Credentials to use for this - client. If not passed (and if no ``_http`` object is - passed), falls back to the default inferred from the - environment. - client_options (Optional[Union[dict, google.api_core.client_options.ClientOptions]]): - Client options used to set user options - on the client. API Endpoint should be set through client_options. 
- Raises: - - RuntimeError if called outside of an async context (no running event loop) - - ValueError if pool_size is less than 1 - """ - transport_str = f"bt-{self._client_version()}-{pool_size}" - transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) - BigtableClientMeta._transport_registry[transport_str] = transport - client_info = DEFAULT_CLIENT_INFO - client_info.client_library_version = self._client_version() - if type(client_options) is dict: - client_options = client_options_lib.from_dict(client_options) - client_options = cast( - Optional[client_options_lib.ClientOptions], client_options - ) - self._emulator_host = os.getenv(BIGTABLE_EMULATOR) - if self._emulator_host is not None: - if credentials is None: - credentials = google.auth.credentials.AnonymousCredentials() - if project is None: - project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT - ClientWithProject.__init__( - self, - credentials=credentials, - project=project, - client_options=client_options, - ) - self._gapic_client = BigtableAsyncClient( - transport=transport_str, - credentials=credentials, - client_options=client_options, - client_info=client_info, - ) - self._is_closed = asyncio.Event() - self.transport = cast( - PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport - ) - self._active_instances: Set[_helpers._WarmedInstanceKey] = set() - self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} - self._channel_init_time = time.monotonic() - self._channel_refresh_tasks: list[asyncio.Task[None]] = [] - self._executor = concurrent.futures.ThreadPoolExecutor() if not False else None - if self._emulator_host is not None: - warnings.warn( - "Connecting to Bigtable emulator at {}".format(self._emulator_host), - RuntimeWarning, - stacklevel=2, - ) - self.transport._grpc_channel = AsyncPooledChannel( - pool_size=pool_size, host=self._emulator_host, insecure=True - ) - self.transport._stubs = {} - self.transport._prep_wrapped_messages(client_info) - else: - try: - self._start_background_channel_refresh() - except RuntimeError: - warnings.warn( - f"{self.__class__.__name__} should be started in an asyncio event loop. 
Channel refresh will not be started", - RuntimeWarning, - stacklevel=2, - ) - - @staticmethod - def _client_version() -> str: - """Helper function to return the client version string for this client""" - if False: - return f"{google.cloud.bigtable.__version__}-data-async" - else: - return f"{google.cloud.bigtable.__version__}-data" - - def _start_background_channel_refresh(self) -> None: - """ - Starts a background task to ping and warm each channel in the pool - Raises: - - RuntimeError if not called in an asyncio event loop - """ - if ( - not self._channel_refresh_tasks - and (not self._emulator_host) - and (not self._is_closed.is_set()) - ): - for channel_idx in range(self.transport.pool_size): - refresh_task = CrossSync.create_task( - self._manage_channel, - channel_idx, - sync_executor=self._executor, - task_name=f"{self.__class__.__name__} channel refresh {channel_idx}", - ) - self._channel_refresh_tasks.append(refresh_task) - refresh_task.add_done_callback( - lambda _: self._channel_refresh_tasks.remove(refresh_task) - ) - - def close(self, timeout: float | None = None): - """Cancel all background tasks""" - self._is_closed.set() - for task in self._channel_refresh_tasks: - task.cancel() - self.transport.close() - if self._executor: - self._executor.shutdown(wait=False) - CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) - - def _ping_and_warm_instances( - self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None - ) -> list[BaseException | None]: - """ - Prepares the backend for requests on a channel - - Pings each Bigtable instance registered in `_active_instances` on the client - - Args: - - channel: grpc channel to warm - - instance_key: if provided, only warm the instance associated with the key - Returns: - - sequence of results or exceptions from the ping requests - """ - instance_list = ( - [instance_key] if instance_key is not None else self._active_instances - ) - ping_rpc = channel.unary_unary( - "/google.bigtable.v2.Bigtable/PingAndWarm", - request_serializer=PingAndWarmRequest.serialize, - ) - partial_list = [ - partial( - ping_rpc, - request={"name": instance_name, "app_profile_id": app_profile_id}, - metadata=[ - ( - "x-goog-request-params", - f"name={instance_name}&app_profile_id={app_profile_id}", - ) - ], - wait_for_ready=True, - ) - for (instance_name, table_name, app_profile_id) in instance_list - ] - result_list = CrossSync.gather_partials( - partial_list, return_exceptions=True, sync_executor=self._executor - ) - return [r or None for r in result_list] - - def _manage_channel( - self, - channel_idx: int, - refresh_interval_min: float = 60 * 35, - refresh_interval_max: float = 60 * 45, - grace_period: float = 60 * 10, - ) -> None: - """ - Background coroutine that periodically refreshes and warms a grpc channel - - The backend will automatically close channels after 60 minutes, so - `refresh_interval` + `grace_period` should be < 60 minutes - - Runs continuously until the client is closed - - Args: - channel_idx: index of the channel in the transport's channel pool - refresh_interval_min: minimum interval before initiating refresh - process in seconds. Actual interval will be a random value - between `refresh_interval_min` and `refresh_interval_max` - refresh_interval_max: maximum interval before initiating refresh - process in seconds. 
Actual interval will be a random value - between `refresh_interval_min` and `refresh_interval_max` - grace_period: time to allow previous channel to serve existing - requests before closing, in seconds - """ - sleep_fn = asyncio.sleep if False else self._is_closed.wait - first_refresh = self._channel_init_time + random.uniform( - refresh_interval_min, refresh_interval_max - ) - next_sleep = max(first_refresh - time.monotonic(), 0) - if next_sleep > 0: - channel = self.transport.channels[channel_idx] - self._ping_and_warm_instances(channel) - while not self._is_closed.is_set(): - sleep_fn(next_sleep) - if self._is_closed.is_set(): - break - new_channel = self.transport.grpc_channel._create_channel() - self._ping_and_warm_instances(new_channel) - start_timestamp = time.monotonic() - self.transport.replace_channel( - channel_idx, - grace=grace_period, - new_channel=new_channel, - event=self._is_closed, - ) - next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) - next_sleep = next_refresh - (time.monotonic() - start_timestamp) - - def _register_instance(self, instance_id: str, owner: TableAsync) -> None: - """ - Registers an instance with the client, and warms the channel pool - for the instance - The client will periodically refresh grpc channel pool used to make - requests, and new channels will be warmed for each registered instance - Channels will not be refreshed unless at least one instance is registered - - Args: - - instance_id: id of the instance to register. - - owner: table that owns the instance. Owners will be tracked in - _instance_owners, and instances will only be unregistered when all - owners call _remove_instance_registration - """ - instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _helpers._WarmedInstanceKey( - instance_name, owner.table_name, owner.app_profile_id - ) - self._instance_owners.setdefault(instance_key, set()).add(id(owner)) - if instance_name not in self._active_instances: - self._active_instances.add(instance_key) - if self._channel_refresh_tasks: - for channel in self.transport.channels: - self._ping_and_warm_instances(channel, instance_key) - else: - self._start_background_channel_refresh() - - def _remove_instance_registration( - self, instance_id: str, owner: TableAsync - ) -> bool: - """ - Removes an instance from the client's registered instances, to prevent - warming new channels for the instance - - If instance_id is not registered, or is still in use by other tables, returns False - - Args: - - instance_id: id of the instance to remove - - owner: table that owns the instance. Owners will be tracked in - _instance_owners, and instances will only be unregistered when all - owners call _remove_instance_registration - Returns: - - True if instance was removed - """ - instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _helpers._WarmedInstanceKey( - instance_name, owner.table_name, owner.app_profile_id - ) - owner_list = self._instance_owners.get(instance_key, set()) - try: - owner_list.remove(id(owner)) - if len(owner_list) == 0: - self._active_instances.remove(instance_key) - return True - except KeyError: - return False - - def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: - """ - Returns a table instance for making data API requests. All arguments are passed - directly to the TableAsync constructor. - - Args: - instance_id: The Bigtable instance ID to associate with this client. 
- instance_id is combined with the client's project to fully - specify the instance - table_id: The ID of the table. table_id is combined with the - instance_id and the client's project to fully specify the table - app_profile_id: The app profile to associate with requests. - https://cloud.google.com/bigtable/docs/app-profiles - default_read_rows_operation_timeout: The default timeout for read rows - operations, in seconds. If not set, defaults to 600 seconds (10 minutes) - default_read_rows_attempt_timeout: The default timeout for individual - read rows rpc requests, in seconds. If not set, defaults to 20 seconds - default_mutate_rows_operation_timeout: The default timeout for mutate rows - operations, in seconds. If not set, defaults to 600 seconds (10 minutes) - default_mutate_rows_attempt_timeout: The default timeout for individual - mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds - default_operation_timeout: The default timeout for all other operations, in - seconds. If not set, defaults to 60 seconds - default_attempt_timeout: The default timeout for all other individual rpc - requests, in seconds. If not set, defaults to 20 seconds - default_read_rows_retryable_errors: a list of errors that will be retried - if encountered during read_rows and related operations. - Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) - default_mutate_rows_retryable_errors: a list of errors that will be retried - if encountered during mutate_rows and related operations. - Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) - default_retryable_errors: a list of errors that will be retried if - encountered during all other operations. - Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) - """ - return TableAsync(self, instance_id, table_id, *args, **kwargs) - - def __enter__(self): - self._start_background_channel_refresh() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - self._gapic_client.__exit__(exc_type, exc_val, exc_tb) diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py deleted file mode 100644 index 1841c814b..000000000 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from __future__ import annotations - - -from google.cloud.bigtable.data._sync._autogen import _MutateRowsOperation_SyncGen - - -class _MutateRowsOperation(_MutateRowsOperation_SyncGen): - pass diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py deleted file mode 100644 index f43822ba4..000000000 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from __future__ import annotations - - -from google.cloud.bigtable.data._sync._autogen import _ReadRowsOperation_SyncGen - - -class _ReadRowsOperation(_ReadRowsOperation_SyncGen): - pass diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index b577f37e1..af34b4c45 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -12,99 +12,1216 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from __future__ import annotations +# This file is automatically generated by sync_surface_generator.py. Do not edit. -from typing import TYPE_CHECKING -import google.auth.credentials +from __future__ import annotations +from abc import ABC +from functools import partial +from grpc import Channel +from typing import Any +from typing import Optional +from typing import Sequence +from typing import Set +from typing import cast +import asyncio import concurrent.futures +import os +import random +import time +import warnings -from google.cloud.bigtable.data._sync._autogen import BigtableDataClient_SyncGen -from google.cloud.bigtable.data._sync._autogen import Table_SyncGen +from google.api_core import client_options as client_options_lib +from google.api_core import retry as retries +from google.api_core.exceptions import Aborted +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import ServiceUnavailable +from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT +from google.cloud.bigtable.data import _helpers +from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync +from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync +from google.cloud.bigtable.data._async.client import BigtableDataClientAsync +from google.cloud.bigtable.data._async.client import TableAsync +from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync +from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE +from google.cloud.bigtable.data._helpers import RowKeySamples +from google.cloud.bigtable.data._helpers import ShardedQuery +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data.exceptions import FailedQueryShardError +from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup +from google.cloud.bigtable.data.mutations import Mutation +from google.cloud.bigtable.data.mutations import RowMutationEntry +from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.row import Row +from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter +from google.cloud.bigtable.data.row_filters import RowFilter +from google.cloud.bigtable.data.row_filters import RowFilterChain +from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter +from 
google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient +from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO +from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, +) +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel as AsyncPooledChannel, +) +from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest +from google.cloud.client import ClientWithProject +from google.cloud.environment_vars import BIGTABLE_EMULATOR +import google.auth._default +import google.auth.credentials -# import required so Table_SyncGen can create _MutateRowsOperation and _ReadRowsOperation -import google.cloud.bigtable.data._sync._read_rows # noqa: F401 -import google.cloud.bigtable.data._sync._mutate_rows # noqa: F401 -if TYPE_CHECKING: - from google.cloud.bigtable.data.row import Row +@CrossSync.sync_output("google.cloud.bigtable.data._sync.client.BigtableDataClient") +class BigtableDataClient(ClientWithProject, ABC): + def __init__( + self, + *, + project: str | None = None, + pool_size: int = 3, + credentials: google.auth.credentials.Credentials | None = None, + client_options: dict[str, Any] + | "google.api_core.client_options.ClientOptions" + | None = None, + ): + """ + Create a client instance for the Bigtable Data API + Client should be created within an async context (running event loop) -class BigtableDataClient(BigtableDataClient_SyncGen): - @property - def _executor(self) -> concurrent.futures.ThreadPoolExecutor: - if not hasattr(self, "_executor_instance"): - self._executor_instance = concurrent.futures.ThreadPoolExecutor() - return self._executor_instance + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + project: the project which the client acts on behalf of. + If not passed, falls back to the default inferred + from the environment. + pool_size: The number of grpc channels to maintain + in the internal channel pool. + credentials: + Thehe OAuth2 Credentials to use for this + client. If not passed (and if no ``_http`` object is + passed), falls back to the default inferred from the + environment. + client_options (Optional[Union[dict, google.api_core.client_options.ClientOptions]]): + Client options used to set user options + on the client. API Endpoint should be set through client_options. 
+ Raises: + - RuntimeError if called outside of an async context (no running event loop) + - ValueError if pool_size is less than 1 + """ + transport_str = f"bt-{self._client_version()}-{pool_size}" + transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) + BigtableClientMeta._transport_registry[transport_str] = transport + client_info = DEFAULT_CLIENT_INFO + client_info.client_library_version = self._client_version() + if type(client_options) is dict: + client_options = client_options_lib.from_dict(client_options) + client_options = cast( + Optional[client_options_lib.ClientOptions], client_options + ) + self._emulator_host = os.getenv(BIGTABLE_EMULATOR) + if self._emulator_host is not None: + if credentials is None: + credentials = google.auth.credentials.AnonymousCredentials() + if project is None: + project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT + ClientWithProject.__init__( + self, + credentials=credentials, + project=project, + client_options=client_options, + ) + self._gapic_client = BigtableAsyncClient( + transport=transport_str, + credentials=credentials, + client_options=client_options, + client_info=client_info, + ) + self._is_closed = asyncio.Event() + self.transport = cast( + PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport + ) + self._active_instances: Set[_helpers._WarmedInstanceKey] = set() + self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} + self._channel_init_time = time.monotonic() + self._channel_refresh_tasks: list[asyncio.Task[None]] = [] + self._executor = concurrent.futures.ThreadPoolExecutor() if not False else None + if self._emulator_host is not None: + warnings.warn( + "Connecting to Bigtable emulator at {}".format(self._emulator_host), + RuntimeWarning, + stacklevel=2, + ) + self.transport._grpc_channel = AsyncPooledChannel( + pool_size=pool_size, host=self._emulator_host, insecure=True + ) + self.transport._stubs = {} + self.transport._prep_wrapped_messages(client_info) + else: + try: + self._start_background_channel_refresh() + except RuntimeError: + warnings.warn( + f"{self.__class__.__name__} should be started in an asyncio event loop. 
Channel refresh will not be started", + RuntimeWarning, + stacklevel=2, + ) @staticmethod def _client_version() -> str: - return f"{google.cloud.bigtable.__version__}-data" + """Helper function to return the client version string for this client""" + if False: + return f"{google.cloud.bigtable.__version__}-data-async" + else: + return f"{google.cloud.bigtable.__version__}-data" def _start_background_channel_refresh(self) -> None: + """ + Starts a background task to ping and warm each channel in the pool + Raises: + - RuntimeError if not called in an asyncio event loop + """ if ( not self._channel_refresh_tasks - and not self._emulator_host - and not self._is_closed.is_set() + and (not self._emulator_host) + and (not self._is_closed.is_set()) ): for channel_idx in range(self.transport.pool_size): - self._channel_refresh_tasks.append( - self._executor.submit(self._manage_channel, channel_idx) + refresh_task = create_task_sync( + self._manage_channel, + channel_idx, + sync_executor=self._executor, + task_name=f"{self.__class__.__name__} channel refresh {channel_idx}", + ) + self._channel_refresh_tasks.append(refresh_task) + refresh_task.add_done_callback( + lambda _: self._channel_refresh_tasks.remove(refresh_task) ) - def _execute_ping_and_warms(self, *fns) -> list[BaseException | None]: - futures_list = [self._executor.submit(f) for f in fns] - results_list: list[BaseException | None] = [] - for future in futures_list: - try: - future.result() - results_list.append(None) - except BaseException as e: - results_list.append(e) - return results_list + def close(self, timeout: float | None = None): + """Cancel all background tasks""" + self._is_closed.set() + for task in self._channel_refresh_tasks: + task.cancel() + self.transport.close() + if self._executor: + self._executor.shutdown(wait=False) + wait_sync(self._channel_refresh_tasks, timeout=timeout) + + def _ping_and_warm_instances( + self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None + ) -> list[BaseException | None]: + """ + Prepares the backend for requests on a channel + + Pings each Bigtable instance registered in `_active_instances` on the client - def close(self) -> None: + Args: + - channel: grpc channel to warm + - instance_key: if provided, only warm the instance associated with the key + Returns: + - sequence of results or exceptions from the ping requests """ - Close the client and all associated resources + instance_list = ( + [instance_key] if instance_key is not None else self._active_instances + ) + ping_rpc = channel.unary_unary( + "/google.bigtable.v2.Bigtable/PingAndWarm", + request_serializer=PingAndWarmRequest.serialize, + ) + partial_list = [ + partial( + ping_rpc, + request={"name": instance_name, "app_profile_id": app_profile_id}, + metadata=[ + ( + "x-goog-request-params", + f"name={instance_name}&app_profile_id={app_profile_id}", + ) + ], + wait_for_ready=True, + ) + for (instance_name, table_name, app_profile_id) in instance_list + ] + result_list = gather_partials_sync( + partial_list, return_exceptions=True, sync_executor=self._executor + ) + return [r or None for r in result_list] - This method should be called when the client is no longer needed. 
+ def _manage_channel( + self, + channel_idx: int, + refresh_interval_min: float = 60 * 35, + refresh_interval_max: float = 60 * 45, + grace_period: float = 60 * 10, + ) -> None: """ - self._is_closed.set() - self._executor.shutdown(wait=True) - self._channel_refresh_tasks = [] - self.transport.close() + Background coroutine that periodically refreshes and warms a grpc channel + + The backend will automatically close channels after 60 minutes, so + `refresh_interval` + `grace_period` should be < 60 minutes + Runs continuously until the client is closed -class Table(Table_SyncGen): - def _register_with_client(self) -> concurrent.futures.Future[None]: - return self.client._executor.submit( - self.client._register_instance, self.instance_id, self + Args: + channel_idx: index of the channel in the transport's channel pool + refresh_interval_min: minimum interval before initiating refresh + process in seconds. Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + refresh_interval_max: maximum interval before initiating refresh + process in seconds. Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + grace_period: time to allow previous channel to serve existing + requests before closing, in seconds + """ + sleep_fn = asyncio.sleep if False else self._is_closed.wait + first_refresh = self._channel_init_time + random.uniform( + refresh_interval_min, refresh_interval_max ) + next_sleep = max(first_refresh - time.monotonic(), 0) + if next_sleep > 0: + channel = self.transport.channels[channel_idx] + self._ping_and_warm_instances(channel) + while not self._is_closed.is_set(): + sleep_fn(next_sleep) + if self._is_closed.is_set(): + break + new_channel = self.transport.grpc_channel._create_channel() + self._ping_and_warm_instances(new_channel) + start_timestamp = time.monotonic() + self.transport.replace_channel( + channel_idx, + grace=grace_period, + new_channel=new_channel, + event=self._is_closed, + ) + next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) + next_sleep = next_refresh - (time.monotonic() - start_timestamp) - def _shard_batch_helper( - self, kwargs_list: list[dict] - ) -> list[list[Row] | BaseException]: - futures_list = [ - self.client._executor.submit(self.read_rows, **kwargs) - for kwargs in kwargs_list - ] - results_list: list[list[Row] | BaseException] = [] - for future in futures_list: - if future.exception(): - results_list.append(future.exception()) + def _register_instance(self, instance_id: str, owner: TableAsync) -> None: + """ + Registers an instance with the client, and warms the channel pool + for the instance + The client will periodically refresh grpc channel pool used to make + requests, and new channels will be warmed for each registered instance + Channels will not be refreshed unless at least one instance is registered + + Args: + - instance_id: id of the instance to register. + - owner: table that owns the instance. 
Owners will be tracked in + _instance_owners, and instances will only be unregistered when all + owners call _remove_instance_registration + """ + instance_name = self._gapic_client.instance_path(self.project, instance_id) + instance_key = _helpers._WarmedInstanceKey( + instance_name, owner.table_name, owner.app_profile_id + ) + self._instance_owners.setdefault(instance_key, set()).add(id(owner)) + if instance_name not in self._active_instances: + self._active_instances.add(instance_key) + if self._channel_refresh_tasks: + for channel in self.transport.channels: + self._ping_and_warm_instances(channel, instance_key) else: - result = future.result() - if result is not None: - results_list.append(result) + self._start_background_channel_refresh() + + def _remove_instance_registration( + self, instance_id: str, owner: TableAsync + ) -> bool: + """ + Removes an instance from the client's registered instances, to prevent + warming new channels for the instance + + If instance_id is not registered, or is still in use by other tables, returns False + + Args: + - instance_id: id of the instance to remove + - owner: table that owns the instance. Owners will be tracked in + _instance_owners, and instances will only be unregistered when all + owners call _remove_instance_registration + Returns: + - True if instance was removed + """ + instance_name = self._gapic_client.instance_path(self.project, instance_id) + instance_key = _helpers._WarmedInstanceKey( + instance_name, owner.table_name, owner.app_profile_id + ) + owner_list = self._instance_owners.get(instance_key, set()) + try: + owner_list.remove(id(owner)) + if len(owner_list) == 0: + self._active_instances.remove(instance_key) + return True + except KeyError: + return False + + def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: + """ + Returns a table instance for making data API requests. All arguments are passed + directly to the TableAsync constructor. + + Args: + instance_id: The Bigtable instance ID to associate with this client. + instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. table_id is combined with the + instance_id and the client's project to fully specify the table + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to 20 seconds + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to 60 seconds + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. 
+ Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + """ + return TableAsync(self, instance_id, table_id, *args, **kwargs) + + def __enter__(self): + self._start_background_channel_refresh() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + self._gapic_client.__exit__(exc_type, exc_val, exc_tb) + + +@CrossSync.sync_output("google.cloud.bigtable.data._sync.client.Table") +class Table(ABC): + """ + Main Data API surface + + Table object maintains table_id, and app_profile_id context, and passes them with + each call + """ + + def __init__( + self, + client: BigtableDataClientAsync, + instance_id: str, + table_id: str, + app_profile_id: str | None = None, + *, + default_read_rows_operation_timeout: float = 600, + default_read_rows_attempt_timeout: float | None = 20, + default_mutate_rows_operation_timeout: float = 600, + default_mutate_rows_attempt_timeout: float | None = 60, + default_operation_timeout: float = 60, + default_attempt_timeout: float | None = 20, + default_read_rows_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + Aborted, + ), + default_mutate_rows_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + ), + default_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + ), + ): + """ + Initialize a Table instance + + Must be created within an async context (running event loop) + + Args: + instance_id: The Bigtable instance ID to associate with this client. + instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. table_id is combined with the + instance_id and the client's project to fully specify the table + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to 20 seconds + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to 60 seconds + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. 
+ Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + Raises: + - RuntimeError if called outside of an async context (no running event loop) + """ + _helpers._validate_timeouts( + default_operation_timeout, default_attempt_timeout, allow_none=True + ) + _helpers._validate_timeouts( + default_read_rows_operation_timeout, + default_read_rows_attempt_timeout, + allow_none=True, + ) + _helpers._validate_timeouts( + default_mutate_rows_operation_timeout, + default_mutate_rows_attempt_timeout, + allow_none=True, + ) + self.client = client + self.instance_id = instance_id + self.instance_name = self.client._gapic_client.instance_path( + self.client.project, instance_id + ) + self.table_id = table_id + self.table_name = self.client._gapic_client.table_path( + self.client.project, instance_id, table_id + ) + self.app_profile_id = app_profile_id + self.default_operation_timeout = default_operation_timeout + self.default_attempt_timeout = default_attempt_timeout + self.default_read_rows_operation_timeout = default_read_rows_operation_timeout + self.default_read_rows_attempt_timeout = default_read_rows_attempt_timeout + self.default_mutate_rows_operation_timeout = ( + default_mutate_rows_operation_timeout + ) + self.default_mutate_rows_attempt_timeout = default_mutate_rows_attempt_timeout + self.default_read_rows_retryable_errors = ( + default_read_rows_retryable_errors or () + ) + self.default_mutate_rows_retryable_errors = ( + default_mutate_rows_retryable_errors or () + ) + self.default_retryable_errors = default_retryable_errors or () + try: + self._register_instance_future = create_task_sync( + self.client._register_instance, + self.instance_id, + self, + sync_executor=self.client._executor, + ) + except RuntimeError as e: + raise RuntimeError( + f"{self.__class__.__name__} must be created within an async event loop context." + ) from e + + def read_rows_stream( + self, + query: ReadRowsQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> Iterable[Row]: + """ + Read a set of rows from the table, based on the specified query. + Returns an iterator to asynchronously stream back row data. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - query: contains details about which rows to return + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. 
+ Defaults to the Table's default_read_rows_retryable_errors + Returns: + - an asynchronous iterator that yields rows returned by the query + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + (operation_timeout, attempt_timeout) = _helpers._get_timeouts( + operation_timeout, attempt_timeout, self + ) + retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) + row_merger = _ReadRowsOperationAsync( + query, + self, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_exceptions=retryable_excs, + ) + return row_merger.start_operation() + + def read_rows( + self, + query: ReadRowsQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> list[Row]: + """ + Read a set of rows from the table, based on the specified query. + Retruns results as a list of Row objects when the request is complete. + For streamed results, use read_rows_stream. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - query: contains details about which rows to return + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + If None, defaults to the Table's default_read_rows_attempt_timeout, + or the operation_timeout if that is also None. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. + Returns: + - a list of Rows returned by the query + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + row_generator = self.read_rows_stream( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) + return [row for row in row_generator] + + def read_row( + self, + row_key: str | bytes, + *, + row_filter: RowFilter | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> Row | None: + """ + Read a single row from the table, based on the specified key. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. 
+ + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - query: contains details about which rows to return + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. + Returns: + - a Row object if the row exists, otherwise None + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + if row_key is None: + raise ValueError("row_key must be string or bytes") + query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1) + results = self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) + if len(results) == 0: + return None + return results[0] + + def read_rows_sharded( + self, + sharded_query: ShardedQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> list[Row]: + """ + Runs a sharded query in parallel, then return the results in a single list. + Results will be returned in the order of the input queries. + + This function is intended to be run on the results on a query.shard() call: + + ``` + table_shard_keys = await table.sample_row_keys() + query = ReadRowsQuery(...) + shard_queries = query.shard(table_shard_keys) + results = await table.read_rows_sharded(shard_queries) + ``` + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - sharded_query: a sharded query to execute + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. 
+ Raises: + - ShardedReadRowsExceptionGroup: if any of the queries failed + - ValueError: if the query_list is empty + """ + if not sharded_query: + raise ValueError("empty sharded_query") + (operation_timeout, attempt_timeout) = _helpers._get_timeouts( + operation_timeout, attempt_timeout, self + ) + timeout_generator = _helpers._attempt_timeout_generator( + operation_timeout, operation_timeout + ) + batched_queries = [ + sharded_query[i : i + _helpers._CONCURRENCY_LIMIT] + for i in range(0, len(sharded_query), _helpers._CONCURRENCY_LIMIT) + ] + results_list = [] + error_dict = {} + shard_idx = 0 + for batch in batched_queries: + batch_operation_timeout = next(timeout_generator) + batch_partial_list = [ + partial( + self.read_rows, + query=query, + operation_timeout=batch_operation_timeout, + attempt_timeout=min(attempt_timeout, batch_operation_timeout), + retryable_errors=retryable_errors, + ) + for query in batch + ] + batch_result = gather_partials_sync( + batch_partial_list, + return_exceptions=True, + sync_executor=self.client._executor, + ) + for result in batch_result: + if isinstance(result, Exception): + error_dict[shard_idx] = result + elif isinstance(result, BaseException): + raise result + else: + results_list.extend(result) + shard_idx += 1 + if error_dict: + raise ShardedReadRowsExceptionGroup( + [ + FailedQueryShardError(idx, sharded_query[idx], e) + for (idx, e) in error_dict.items() + ], + results_list, + len(sharded_query), + ) return results_list + def row_exists( + self, + row_key: str | bytes, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> bool: + """ + Return a boolean indicating whether the specified row exists in the table. + uses the filters: chain(limit cells per row = 1, strip value) + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - row_key: the key of the row to check + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. 
+ Returns: + - a bool indicating whether the row exists + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + if row_key is None: + raise ValueError("row_key must be string or bytes") + strip_filter = StripValueTransformerFilter(flag=True) + limit_filter = CellsRowLimitFilter(1) + chain_filter = RowFilterChain(filters=[limit_filter, strip_filter]) + query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter) + results = self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) + return len(results) > 0 + + def sample_row_keys( + self, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ) -> RowKeySamples: + """ + Return a set of RowKeySamples that delimit contiguous sections of the table of + approximately equal size + + RowKeySamples output can be used with ReadRowsQuery.shard() to create a sharded query that + can be parallelized across multiple backend nodes read_rows and read_rows_stream + requests will call sample_row_keys internally for this purpose when sharding is enabled + + RowKeySamples is simply a type alias for list[tuple[bytes, int]]; a list of + row_keys, along with offset positions in the table + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget.i + Defaults to the Table's default_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_retryable_errors. 
+ Returns: + - a set of RowKeySamples the delimit contiguous sections of the table + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + - GoogleAPIError: raised if the request encounters an unrecoverable error + """ + (operation_timeout, attempt_timeout) = _helpers._get_timeouts( + operation_timeout, attempt_timeout, self + ) + attempt_timeout_gen = _helpers._attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) + predicate = retries.if_exception_type(*retryable_excs) + sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) + + def execute_rpc(): + results = self.client._gapic_client.sample_row_keys( + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=next(attempt_timeout_gen), + metadata=metadata, + retry=None, + ) + return [(s.row_key, s.offset_bytes) for s in results] + + return retries.retry_target_async( + execute_rpc, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_helpers._retry_exception_factory, + ) + + def mutations_batcher( + self, + *, + flush_interval: float | None = 5, + flush_limit_mutation_count: int | None = 1000, + flush_limit_bytes: int = 20 * _MB_SIZE, + flow_control_max_mutation_count: int = 100000, + flow_control_max_bytes: int = 100 * _MB_SIZE, + batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + ) -> MutationsBatcherAsync: + """ + Returns a new mutations batcher instance. + + Can be used to iteratively add mutations that are flushed as a group, + to avoid excess network calls + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - flush_interval: Automatically flush every flush_interval seconds. If None, + a table default will be used + - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count + mutations are added across all entries. If None, this limit is ignored. + - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. + - flow_control_max_mutation_count: Maximum number of inflight mutations. + - flow_control_max_bytes: Maximum number of inflight bytes. + - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. + Defaults to the Table's default_mutate_rows_operation_timeout + - batch_attempt_timeout: timeout for each individual request, in seconds. + Defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to batch_operation_timeout. + - batch_retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_mutate_rows_retryable_errors. 
+ Returns: + - a MutationsBatcherAsync context manager that can batch requests + """ + return MutationsBatcherAsync( + self, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_mutation_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=batch_operation_timeout, + batch_attempt_timeout=batch_attempt_timeout, + batch_retryable_errors=batch_retryable_errors, + ) + + def mutate_row( + self, + row_key: str | bytes, + mutations: list[Mutation] | Mutation, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ): + """ + Mutates a row atomically. + + Cells already present in the row are left unchanged unless explicitly changed + by ``mutation``. + + Idempotent operations (i.e, all mutations have an explicit timestamp) will be + retried on server failure. Non-idempotent operations will not. + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - row_key: the row to apply mutations to + - mutations: the set of mutations to apply to the row + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Only idempotent mutations will be retried. Defaults to the Table's + default_retryable_errors. + Raises: + - DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing all + GoogleAPIError exceptions from any retries that failed + - GoogleAPIError: raised on non-idempotent operations that cannot be + safely retried. 
+ - ValueError if invalid arguments are provided + """ + (operation_timeout, attempt_timeout) = _helpers._get_timeouts( + operation_timeout, attempt_timeout, self + ) + if not mutations: + raise ValueError("No mutations provided") + mutations_list = mutations if isinstance(mutations, list) else [mutations] + if all((mutation.is_idempotent() for mutation in mutations_list)): + predicate = retries.if_exception_type( + *_helpers._get_retryable_errors(retryable_errors, self) + ) + else: + predicate = retries.if_exception_type() + sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + target = partial( + self.client._gapic_client.mutate_row, + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + mutations=[mutation._to_pb() for mutation in mutations_list], + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=attempt_timeout, + metadata=_helpers._make_metadata(self.table_name, self.app_profile_id), + retry=None, + ) + return retries.retry_target_async( + target, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_helpers._retry_exception_factory, + ) + + def bulk_mutate_rows( + self, + mutation_entries: list[RowMutationEntry], + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + ): + """ + Applies mutations for multiple rows in a single batched request. + + Each individual RowMutationEntry is applied atomically, but separate entries + may be applied in arbitrary order (even for entries targetting the same row) + In total, the row_mutations can contain at most 100000 individual mutations + across all entries + + Idempotent entries (i.e., entries with mutations with explicit timestamps) + will be retried on failure. Non-idempotent will not, and will reported in a + raised exception group + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - mutation_entries: the batches of mutations to apply + Each entry will be applied atomically, but entries will be applied + in arbitrary order + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_mutate_rows_operation_timeout + - attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. 
+ Defaults to the Table's default_mutate_rows_retryable_errors + Raises: + - MutationsExceptionGroup if one or more mutations fails + Contains details about any failed entries in .exceptions + - ValueError if invalid arguments are provided + """ + (operation_timeout, attempt_timeout) = _helpers._get_timeouts( + operation_timeout, attempt_timeout, self + ) + retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) + operation = _MutateRowsOperationAsync( + self.client._gapic_client, + self, + mutation_entries, + operation_timeout, + attempt_timeout, + retryable_exceptions=retryable_excs, + ) + operation.start() + + def check_and_mutate_row( + self, + row_key: str | bytes, + predicate: RowFilter | None, + *, + true_case_mutations: Mutation | list[Mutation] | None = None, + false_case_mutations: Mutation | list[Mutation] | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ) -> bool: + """ + Mutates a row atomically based on the output of a predicate filter + + Non-idempotent operation: will not be retried + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - row_key: the key of the row to mutate + - predicate: the filter to be applied to the contents of the specified row. + Depending on whether or not any results are yielded, + either true_case_mutations or false_case_mutations will be executed. + If None, checks that the row contains any values at all. + - true_case_mutations: + Changes to be atomically applied to the specified row if + predicate yields at least one cell when + applied to row_key. Entries are applied in order, + meaning that earlier mutations can be masked by later + ones. Must contain at least one entry if + false_case_mutations is empty, and at most 100000. + - false_case_mutations: + Changes to be atomically applied to the specified row if + predicate_filter does not yield any cells when + applied to row_key. Entries are applied in order, + meaning that earlier mutations can be masked by later + ones. Must contain at least one entry if + `true_case_mutations is empty, and at most 100000. + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will not be retried. 
Defaults to the Table's default_operation_timeout + Returns: + - bool indicating whether the predicate was true or false + Raises: + - GoogleAPIError exceptions from grpc call + """ + (operation_timeout, _) = _helpers._get_timeouts(operation_timeout, None, self) + if true_case_mutations is not None and ( + not isinstance(true_case_mutations, list) + ): + true_case_mutations = [true_case_mutations] + true_case_list = [m._to_pb() for m in true_case_mutations or []] + if false_case_mutations is not None and ( + not isinstance(false_case_mutations, list) + ): + false_case_mutations = [false_case_mutations] + false_case_list = [m._to_pb() for m in false_case_mutations or []] + metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) + result = self.client._gapic_client.check_and_mutate_row( + true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) + return result.predicate_matched + + def read_modify_write_row( + self, + row_key: str | bytes, + rules: ReadModifyWriteRule | list[ReadModifyWriteRule], + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ) -> Row: + """ + Reads and modifies a row atomically according to input ReadModifyWriteRules, + and returns the contents of all modified cells + + The new value for the timestamp is the greater of the existing timestamp or + the current server time. + + Non-idempotent operation: will not be retried + + Warning: BigtableDataClientAsync is currently in preview, and is not + yet recommended for production use. + + Args: + - row_key: the key of the row to apply read/modify/write rules to + - rules: A rule or set of rules to apply to the row. + Rules are applied in order, meaning that earlier rules will affect the + results of later ones. + - operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will not be retried. + Defaults to the Table's default_operation_timeout. 
+ Returns: + - Row: containing cell data that was modified as part of the + operation + Raises: + - GoogleAPIError exceptions from grpc call + - ValueError if invalid arguments are provided + """ + (operation_timeout, _) = _helpers._get_timeouts(operation_timeout, None, self) + if operation_timeout <= 0: + raise ValueError("operation_timeout must be greater than 0") + if rules is not None and (not isinstance(rules, list)): + rules = [rules] + if not rules: + raise ValueError("rules must contain at least one item") + metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) + result = self.client._gapic_client.read_modify_write_row( + rules=[rule._to_pb() for rule in rules], + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) + return Row._from_pb(result.row) + + def close(self): + """Called to close the Table instance and release any resources held by it.""" + if self._register_instance_future: + self._register_instance_future.cancel() + self.client._remove_instance_registration(self.instance_id, self) + def __enter__(self): """ - Implement context manager protocol + Implement async context manager protocol Ensure registration task has time to run, so that grpc channels will be warmed for the specified instance """ if self._register_instance_future: - self._register_instance_future.result() + self._register_instance_future return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Implement async context manager protocol + + Unregister this instance with the client, so that + grpc channels will no longer be warmed + """ + self.close() diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index d648d601a..c1d27d4b5 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -1,5 +1,21 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + import asyncio import sys +import concurrent.futures class CrossSync: @@ -43,6 +59,26 @@ async def gather_partials(partial_list, return_exceptions=False, sync_executor=N awaitable_list = [partial() for partial in partial_list] return await asyncio.gather(*awaitable_list, return_exceptions=return_exceptions) + @staticmethod + def gather_partials_sync(partial_list, return_exceptions=False, sync_executor=None): + if not partial_list: + return [] + if not sync_executor: + raise ValueError("sync_executor is required for sync version") + futures_list = [ + sync_executor.submit(partial) for partial in partial_list + ] + results_list = [] + for future in futures_list: + if future.exception(): + if return_exceptions: + results_list.append(future.exception()) + else: + raise future.exception() + else: + results_list.append(future.result()) + return results_list + @staticmethod async def wait(futures, timeout=None): """ @@ -52,6 +88,15 @@ async def wait(futures, timeout=None): return set(), set() return await asyncio.wait(futures, timeout=timeout) + @staticmethod + def wait_sync(futures, timeout=None): + """ + abstraction over asyncio.wait + """ + if not futures: + return set(), set() + return concurrent.futures.wait(futures, timeout=timeout) + @staticmethod def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): """ @@ -63,3 +108,14 @@ def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): if task_name and sys.version_info >= (3, 8): task.set_name(task_name) return task + + @staticmethod + def create_task_sync(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): + """ + abstraction over asyncio.create_task. Sync version implemented with threadpool executor + + sync_executor: ThreadPoolExecutor to use for sync operations. Ignored in async version + """ + if not sync_executor: + raise ValueError("sync_executor is required for sync version") + return sync_executor.submit(fn, *fn_args, **fn_kwargs) diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index ee5a3aac7..81e46c847 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -12,83 +12,337 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from __future__ import annotations +# This file is automatically generated by sync_surface_generator.py. Do not edit. 
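As a quick point of reference before the generated batcher below, the thread-pool counterparts added to CrossSync above can be exercised on their own. A minimal self-contained sketch, standard library only; the helper bodies mirror the create_task_sync/wait_sync definitions shown in the diff, and the pow call is just an arbitrary workload:

    import concurrent.futures

    def create_task_sync(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs):
        # sync analogue of asyncio.create_task: schedule fn on the executor
        if not sync_executor:
            raise ValueError("sync_executor is required for sync version")
        return sync_executor.submit(fn, *fn_args, **fn_kwargs)

    def wait_sync(futures, timeout=None):
        # sync analogue of asyncio.wait
        if not futures:
            return set(), set()
        return concurrent.futures.wait(futures, timeout=timeout)

    with concurrent.futures.ThreadPoolExecutor() as executor:
        tasks = [create_task_sync(pow, 2, n, sync_executor=executor) for n in range(4)]
        done, pending = wait_sync(tasks, timeout=5)
        print(sorted(f.result() for f in done))  # [1, 2, 4, 8]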
-from typing import TYPE_CHECKING -import concurrent.futures +from __future__ import annotations +from abc import ABC +from collections import deque +from typing import Any +from typing import Sequence +import asyncio import atexit +import warnings -from google.cloud.bigtable.data._sync._autogen import _FlowControl_SyncGen -from google.cloud.bigtable.data._sync._autogen import MutationsBatcher_SyncGen +from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync +from google.cloud.bigtable.data._async.client import TableAsync +from google.cloud.bigtable.data._async.mutations_batcher import _FlowControlAsync +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _get_retryable_errors +from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data.exceptions import FailedMutationEntryError +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup +from google.cloud.bigtable.data.mutations import Mutation +from google.cloud.bigtable.data.mutations import RowMutationEntry -# import required so MutationsBatcher_SyncGen can create _MutateRowsOperation -import google.cloud.bigtable.data._sync._mutate_rows # noqa: F401 -if TYPE_CHECKING: - from google.cloud.bigtable.data.exceptions import FailedMutationEntryError +@CrossSync.sync_output( + "google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher" +) +class MutationsBatcher(ABC): + """ + Allows users to send batches using context manager API: + Runs mutate_row, mutate_rows, and check_and_mutate_row internally, combining + to use as few network requests as required -class _FlowControl(_FlowControl_SyncGen): - pass + Flushes: + - every flush_interval seconds + - after queue reaches flush_count in quantity + - after queue reaches flush_size_bytes in storage size + - when batcher is closed or destroyed + async with table.mutations_batcher() as batcher: + for i in range(10): + batcher.add(row, mut) + """ -class MutationsBatcher(MutationsBatcher_SyncGen): - @property - def _executor(self): + def __init__( + self, + table: "TableAsync", + *, + flush_interval: float | None = 5, + flush_limit_mutation_count: int | None = 1000, + flush_limit_bytes: int = 20 * _MB_SIZE, + flow_control_max_mutation_count: int = 100000, + flow_control_max_bytes: int = 100 * _MB_SIZE, + batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + ): """ - Return a ThreadPoolExecutor for background tasks + Args: + - table: Table to preform rpc calls + - flush_interval: Automatically flush every flush_interval seconds. + If None, no time-based flushing is performed. + - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count + mutations are added across all entries. If None, this limit is ignored. + - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. + - flow_control_max_mutation_count: Maximum number of inflight mutations. + - flow_control_max_bytes: Maximum number of inflight bytes. + - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. + If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_operation_timeout. + - batch_attempt_timeout: timeout for each individual request, in seconds. 
+ If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to batch_operation_timeout. + - batch_retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_mutate_rows_retryable_errors. """ - if not hasattr(self, "_threadpool"): - self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=8) - return self._threadpool + (self._operation_timeout, self._attempt_timeout) = _get_timeouts( + batch_operation_timeout, batch_attempt_timeout, table + ) + self._retryable_errors: list[type[Exception]] = _get_retryable_errors( + batch_retryable_errors, table + ) + self._closed: asyncio.Event = asyncio.Event() + self._table = table + self._staged_entries: list[RowMutationEntry] = [] + (self._staged_count, self._staged_bytes) = (0, 0) + self._flow_control = _FlowControlAsync( + flow_control_max_mutation_count, flow_control_max_bytes + ) + self._flush_limit_bytes = flush_limit_bytes + self._flush_limit_count = ( + flush_limit_mutation_count + if flush_limit_mutation_count is not None + else float("inf") + ) + self._flush_timer = self._create_bg_task(self._timer_routine, flush_interval) + self._flush_jobs: set[asyncio.Future[None]] = set() + self._entries_processed_since_last_raise: int = 0 + self._exceptions_since_last_raise: int = 0 + self._exception_list_limit: int = 10 + self._oldest_exceptions: list[Exception] = [] + self._newest_exceptions: deque[Exception] = deque( + maxlen=self._exception_list_limit + ) + atexit.register(self._on_exit) - def close(self): + def _timer_routine(self, interval: float | None) -> None: + """ + Triggers new flush tasks every `interval` seconds + Ends when the batcher is closed + """ + if not interval or interval <= 0: + return None + while not self._closed.is_set(): + try: + asyncio.wait_for(self._closed.wait(), timeout=interval) + except asyncio.TimeoutError: + pass + if not self._closed.is_set() and self._staged_entries: + self._schedule_flush() + + def append(self, mutation_entry: RowMutationEntry): """ - Flush queue and clean up resources + Add a new set of mutations to the internal queue + + TODO: return a future to track completion of this entry + + Args: + - mutation_entry: new entry to add to flush queue + Raises: + - RuntimeError if batcher is closed + - ValueError if an invalid mutation type is added + """ + if self._closed.is_set(): + raise RuntimeError("Cannot append to closed MutationsBatcher") + if isinstance(mutation_entry, Mutation): + raise ValueError( + f"invalid mutation type: {type(mutation_entry).__name__}. 
Only RowMutationEntry objects are supported by batcher" + ) + self._staged_entries.append(mutation_entry) + self._staged_count += len(mutation_entry.mutations) + self._staged_bytes += mutation_entry.size() + if ( + self._staged_count >= self._flush_limit_count + or self._staged_bytes >= self._flush_limit_bytes + ): + self._schedule_flush() + asyncio.sleep(0) + + def _schedule_flush(self) -> asyncio.Future[None] | None: + """Update the flush task to include the latest staged entries""" + if self._staged_entries: + (entries, self._staged_entries) = (self._staged_entries, []) + (self._staged_count, self._staged_bytes) = (0, 0) + new_task = self._create_bg_task(self._flush_internal, entries) + if not new_task.done(): + self._flush_jobs.add(new_task) + new_task.add_done_callback(self._flush_jobs.remove) + return new_task + return None + + def _flush_internal(self, new_entries: list[RowMutationEntry]): + """ + Flushes a set of mutations to the server, and updates internal state + + Args: + - new_entries: list of RowMutationEntry objects to flush + """ + in_process_requests: list[asyncio.Future[list[FailedMutationEntryError]]] = [] + for batch in self._flow_control.add_to_flow(new_entries): + batch_task = self._create_bg_task(self._execute_mutate_rows, batch) + in_process_requests.append(batch_task) + found_exceptions = self._wait_for_batch_results(*in_process_requests) + self._entries_processed_since_last_raise += len(new_entries) + self._add_exceptions(found_exceptions) + + def _execute_mutate_rows( + self, batch: list[RowMutationEntry] + ) -> list[FailedMutationEntryError]: + """ + Helper to execute mutation operation on a batch + + Args: + - batch: list of RowMutationEntry objects to send to server + - timeout: timeout in seconds. Used as operation_timeout and attempt_timeout. + If not given, will use table defaults + Returns: + - list of FailedMutationEntryError objects for mutations that failed. + FailedMutationEntryError objects will not contain index information + """ + try: + operation = _MutateRowsOperationAsync( + self._table.client._gapic_client, + self._table, + batch, + operation_timeout=self._operation_timeout, + attempt_timeout=self._attempt_timeout, + retryable_exceptions=self._retryable_errors, + ) + operation.start() + except MutationsExceptionGroup as e: + for subexc in e.exceptions: + subexc.index = None + return list(e.exceptions) + finally: + self._flow_control.remove_from_flow(batch) + return [] + + def _add_exceptions(self, excs: list[Exception]): + """ + Add new list of exceptions to internal store. To avoid unbounded memory, + the batcher will store the first and last _exception_list_limit exceptions, + and discard any in between. 
+ """ + self._exceptions_since_last_raise += len(excs) + if excs and len(self._oldest_exceptions) < self._exception_list_limit: + addition_count = self._exception_list_limit - len(self._oldest_exceptions) + self._oldest_exceptions.extend(excs[:addition_count]) + excs = excs[addition_count:] + if excs: + self._newest_exceptions.extend(excs[-self._exception_list_limit :]) + + def _raise_exceptions(self): """ + Raise any unreported exceptions from background flush operations + + Raises: + - MutationsExceptionGroup with all unreported exceptions + """ + if self._oldest_exceptions or self._newest_exceptions: + (oldest, self._oldest_exceptions) = (self._oldest_exceptions, []) + newest = list(self._newest_exceptions) + self._newest_exceptions.clear() + (entry_count, self._entries_processed_since_last_raise) = ( + self._entries_processed_since_last_raise, + 0, + ) + (exc_count, self._exceptions_since_last_raise) = ( + self._exceptions_since_last_raise, + 0, + ) + raise MutationsExceptionGroup.from_truncated_lists( + first_list=oldest, + last_list=newest, + total_excs=exc_count, + entry_count=entry_count, + ) + + def __enter__(self): + """For context manager API""" + return self + + def __exit__(self, exc_type, exc, tb): + """For context manager API""" + self.close() + + @property + def closed(self) -> bool: + """ + Returns: + - True if the batcher is closed, False otherwise + """ + return self._closed.is_set() + + def close(self): + """Flush queue and clean up resources""" self._closed.set() - # attempt cancel timer if not started self._flush_timer.cancel() self._schedule_flush() - with self._executor: - self._executor.shutdown(wait=True) + if self._flush_jobs: + asyncio.gather(*self._flush_jobs, return_exceptions=True) + try: + self._flush_timer + except asyncio.CancelledError: + pass atexit.unregister(self._on_exit) - # raise unreported exceptions self._raise_exceptions() - def _create_bg_task(self, func, *args, **kwargs): - return self._executor.submit(func, *args, **kwargs) + def _on_exit(self): + """Called when program is exited. Raises warning if unflushed mutations remain""" + if not self._closed.is_set() and self._staged_entries: + warnings.warn( + f"MutationsBatcher for table {self._table.table_name} was not closed. {len(self._staged_entries)} Unflushed mutations will not be sent to the server." + ) + + @staticmethod + def _create_bg_task(func, *args, **kwargs) -> asyncio.Future[Any]: + """ + Create a new background task, and return a future + + This method wraps asyncio to make it easier to maintain subclasses + with different concurrency models. 
+ + Args: + - func: function to execute in background task + - *args: positional arguments to pass to func + - **kwargs: keyword arguments to pass to func + Returns: + - Future object representing the background task + """ + return asyncio.create_task(func(*args, **kwargs)) @staticmethod def _wait_for_batch_results( - *tasks: concurrent.futures.Future[list[FailedMutationEntryError]] - | concurrent.futures.Future[None], + *tasks: asyncio.Future[list[FailedMutationEntryError]] | asyncio.Future[None], ) -> list[Exception]: - if not tasks: - return [] - exceptions: list[Exception] = [] - for task in tasks: - try: - exc_list = task.result() - if exc_list: - for exc in exc_list: - # strip index information - exc.index = None - exceptions.extend(exc_list) - except Exception as e: - exceptions.append(e) - return exceptions - - def _timer_routine(self, interval: float | None) -> None: """ - Triggers new flush tasks every `interval` seconds - Ends when the batcher is closed + Takes in a list of futures representing _execute_mutate_rows tasks, + waits for them to complete, and returns a list of errors encountered. + + Args: + - *tasks: futures representing _execute_mutate_rows or _flush_internal tasks + Returns: + - list of Exceptions encountered by any of the tasks. Errors are expected + to be FailedMutationEntryError, representing a failed mutation operation. + If a task fails with a different exception, it will be included in the + output list. Successful tasks will not be represented in the output list. """ - if not interval or interval <= 0: - return None - while not self._closed.is_set(): - # wait until interval has passed, or until closed - self._closed.wait(timeout=interval) - if not self._closed.is_set() and self._staged_entries: - self._schedule_flush() + if not tasks: + return [] + all_results = asyncio.gather(*tasks, return_exceptions=True) + found_errors = [] + for result in all_results: + if isinstance(result, Exception): + found_errors.append(result) + elif isinstance(result, BaseException): + raise result + elif result: + for e in result: + e.index = None + found_errors.extend(result) + return found_errors diff --git a/sync_surface_generator.py b/sync_surface_generator.py index aae969b30..76619187f 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -57,6 +57,9 @@ def __init__(self, *, name=None, cross_sync_replacements=None, text_replacements "Task": "concurrent.futures.Future", "Event": "threading.Event", "is_async": "False", + "gather_partials": "gather_partials_sync", + "wait": "wait_sync", + "create_task": "create_task_sync", } self.text_replacements = text_replacements or {} self.drop_methods = drop_methods or [] @@ -127,9 +130,11 @@ def visit_AsyncFunctionDef(self, node): if hasattr(node, "decorator_list"): # TODO: make generic is_asyncio_decorator = lambda d: all(x in ast.dump(d) for x in ["pytest", "mark", "asyncio"]) + is_cross_sync_decorator = lambda d: all(x in ast.dump(d) for x in ["CrossSync", "sync_output"]) node.decorator_list = [ - d for d in node.decorator_list if not is_asyncio_decorator(d) + d for d in node.decorator_list if not is_asyncio_decorator(d) and not is_cross_sync_decorator(d) ] + # visit string type annotations for arg in node.args.args: if arg.annotation: @@ -399,7 +404,7 @@ def transform_from_config(config_dict: dict): # with open(save_path, "w") as f: # f.write(code) # find all classes in the library - import google.cloud.bigtable.data as data_lib + import google.cloud.bigtable.data._async as data_lib lib_classes = 
inspect.getmembers(data_lib, inspect.isclass) # keep only those with CrossSync annotation enabled_classes = [c[1] for c in lib_classes if hasattr(c[1], "cross_sync_enabled")] From 86c6cd961af26270f0c2a570973cc7ae79f20992 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 30 May 2024 15:18:37 -0700 Subject: [PATCH 059/360] added crosssync to mutations batcher --- .../bigtable/data/_async/mutations_batcher.py | 98 ++++++++-------- google/cloud/bigtable/data/_sync/client.py | 28 ++--- .../cloud/bigtable/data/_sync/cross_sync.py | 31 +++++ .../bigtable/data/_sync/mutations_batcher.py | 109 +++++++++--------- sync_surface_generator.py | 8 +- 5 files changed, 148 insertions(+), 126 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index a8f229083..c7f0660ff 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -72,7 +72,7 @@ def __init__( raise ValueError("max_mutation_count must be greater than 0") if self._max_mutation_bytes < 1: raise ValueError("max_mutation_bytes must be greater than 0") - self._capacity_condition = asyncio.Condition() + self._capacity_condition = CrossSync.Condition() self._in_flight_mutation_count = 0 self._in_flight_mutation_bytes = 0 @@ -144,7 +144,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] while end_idx < len(mutations): next_entry = mutations[end_idx] next_size = next_entry.size() - next_count = len(next_entry.mutations) + next_count= len(next_entry.mutations) if ( self._has_capacity(next_count, next_size) # make sure not to exceed per-request mutation count limits @@ -225,7 +225,7 @@ def __init__( batch_retryable_errors, table ) - self._closed: asyncio.Event = asyncio.Event() + self._closed = CrossSync.Event() self._table = table self._staged_entries: list[RowMutationEntry] = [] self._staged_count, self._staged_bytes = 0, 0 @@ -238,8 +238,8 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - self._flush_timer = self._create_bg_task(self._timer_routine, flush_interval) - self._flush_jobs: set[asyncio.Future[None]] = set() + self._flush_timer = CrossSync.create_task(self._timer_routine, flush_interval, sync_executor=self._sync_executor) + self._flush_jobs: set[CrossSync.Future[None]] = set() # MutationExceptionGroup reports number of successful entries along with failures self._entries_processed_since_last_raise: int = 0 self._exceptions_since_last_raise: int = 0 @@ -251,6 +251,11 @@ def __init__( ) # clean up on program exit atexit.register(self._on_exit) + # in sync mode, use a threadpool executor for background tasks + if not CrossSync.is_async: + self._sync_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) + else: + self._sync_executor = None async def _timer_routine(self, interval: float | None) -> None: """ @@ -261,10 +266,7 @@ async def _timer_routine(self, interval: float | None) -> None: return None while not self._closed.is_set(): # wait until interval has passed, or until closed - try: - await asyncio.wait_for(self._closed.wait(), timeout=interval) - except asyncio.TimeoutError: - pass + await CrossSync.condition_wait(self._closed, timeout=interval) if not self._closed.is_set() and self._staged_entries: self._schedule_flush() @@ -296,14 +298,14 @@ async def append(self, mutation_entry: RowMutationEntry): ): self._schedule_flush() # yield to the event loop to allow flush to run - await asyncio.sleep(0) + 
CrossSync.yield_to_event_loop() - def _schedule_flush(self) -> asyncio.Future[None] | None: + def _schedule_flush(self) -> CrossSync.Future[None] | None: """Update the flush task to include the latest staged entries""" if self._staged_entries: entries, self._staged_entries = self._staged_entries, [] self._staged_count, self._staged_bytes = 0, 0 - new_task = self._create_bg_task(self._flush_internal, entries) + new_task = CrossSync.create_task(self._flush_internal, entries, sync_executor=self._sync_executor) if not new_task.done(): self._flush_jobs.add(new_task) new_task.add_done_callback(self._flush_jobs.remove) @@ -318,9 +320,9 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): - new_entries: list of RowMutationEntry objects to flush """ # flush new entries - in_process_requests: list[asyncio.Future[list[FailedMutationEntryError]]] = [] + in_process_requests: list[CrossSync.Future[list[FailedMutationEntryError]]] = [] async for batch in self._flow_control.add_to_flow(new_entries): - batch_task = self._create_bg_task(self._execute_mutate_rows, batch) + batch_task = CrossSync.create_task(self._execute_mutate_rows, batch, sync_executor=self._sync_executor) in_process_requests.append(batch_task) # wait for all inflight requests to complete found_exceptions = await self._wait_for_batch_results(*in_process_requests) @@ -427,12 +429,18 @@ async def close(self): self._closed.set() self._flush_timer.cancel() self._schedule_flush() - if self._flush_jobs: - await asyncio.gather(*self._flush_jobs, return_exceptions=True) - try: - await self._flush_timer - except asyncio.CancelledError: - pass + if CrossSync.is_async: + # flush remaining tasks + if self._flush_jobs: + await asyncio.gather(*self._flush_jobs, return_exceptions=True) + try: + await self._flush_timer + except asyncio.CancelledError: + pass + else: + # shut down executor + with self._sync_executor: + self._sync_executor.shutdown(wait=True) atexit.unregister(self._on_exit) # raise unreported exceptions self._raise_exceptions() @@ -447,26 +455,10 @@ def _on_exit(self): f"{len(self._staged_entries)} Unflushed mutations will not be sent to the server." ) - @staticmethod - def _create_bg_task(func, *args, **kwargs) -> asyncio.Future[Any]: - """ - Create a new background task, and return a future - - This method wraps asyncio to make it easier to maintain subclasses - with different concurrency models. - - Args: - - func: function to execute in background task - - *args: positional arguments to pass to func - - **kwargs: keyword arguments to pass to func - Returns: - - Future object representing the background task - """ - return asyncio.create_task(func(*args, **kwargs)) @staticmethod async def _wait_for_batch_results( - *tasks: asyncio.Future[list[FailedMutationEntryError]] | asyncio.Future[None], + *tasks: CrossSync.Future[list[FailedMutationEntryError]] | CrossSync.Future[None], ) -> list[Exception]: """ Takes in a list of futures representing _execute_mutate_rows tasks, @@ -482,19 +474,19 @@ async def _wait_for_batch_results( """ if not tasks: return [] - all_results = await asyncio.gather(*tasks, return_exceptions=True) - found_errors = [] - for result in all_results: - if isinstance(result, Exception): - # will receive direct Exception objects if request task fails - found_errors.append(result) - elif isinstance(result, BaseException): - # BaseException not expected from grpc calls. 
Raise immediately - raise result - elif result: - # completed requests will return a list of FailedMutationEntryError - for e in result: - # strip index information - e.index = None - found_errors.extend(result) - return found_errors + exceptions: list[Exception] = [] + for task in tasks: + if CrossSync.is_async: + # futures don't need to be awaited in sync mode + await task + try: + exc_list = task.result() + if exc_list: + # expect a list of FailedMutationEntryError objects + for exc in exc_list: + # strip index information + exc.index = None + exceptions.extend(exc_list) + except Exception as e: + exceptions.append(e) + return exceptions diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index af34b4c45..8fc504281 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -191,7 +191,7 @@ def _start_background_channel_refresh(self) -> None: and (not self._is_closed.is_set()) ): for channel_idx in range(self.transport.pool_size): - refresh_task = create_task_sync( + refresh_task = CrossSync.create_task_sync( self._manage_channel, channel_idx, sync_executor=self._executor, @@ -210,7 +210,7 @@ def close(self, timeout: float | None = None): self.transport.close() if self._executor: self._executor.shutdown(wait=False) - wait_sync(self._channel_refresh_tasks, timeout=timeout) + CrossSync.wait_sync(self._channel_refresh_tasks, timeout=timeout) def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None @@ -245,9 +245,9 @@ def _ping_and_warm_instances( ], wait_for_ready=True, ) - for (instance_name, table_name, app_profile_id) in instance_list + for instance_name, table_name, app_profile_id in instance_list ] - result_list = gather_partials_sync( + result_list = CrossSync.gather_partials_sync( partial_list, return_exceptions=True, sync_executor=self._executor ) return [r or None for r in result_list] @@ -517,7 +517,7 @@ def __init__( ) self.default_retryable_errors = default_retryable_errors or () try: - self._register_instance_future = create_task_sync( + self._register_instance_future = CrossSync.create_task_sync( self.client._register_instance, self.instance_id, self, @@ -567,7 +567,7 @@ def read_rows_stream( from any retries that failed - GoogleAPIError: raised if the request encounters an unrecoverable error """ - (operation_timeout, attempt_timeout) = _helpers._get_timeouts( + operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) @@ -725,7 +725,7 @@ def read_rows_sharded( """ if not sharded_query: raise ValueError("empty sharded_query") - (operation_timeout, attempt_timeout) = _helpers._get_timeouts( + operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) timeout_generator = _helpers._attempt_timeout_generator( @@ -750,7 +750,7 @@ def read_rows_sharded( ) for query in batch ] - batch_result = gather_partials_sync( + batch_result = CrossSync.gather_partials_sync( batch_partial_list, return_exceptions=True, sync_executor=self.client._executor, @@ -767,7 +767,7 @@ def read_rows_sharded( raise ShardedReadRowsExceptionGroup( [ FailedQueryShardError(idx, sharded_query[idx], e) - for (idx, e) in error_dict.items() + for idx, e in error_dict.items() ], results_list, len(sharded_query), @@ -865,7 +865,7 @@ def sample_row_keys( from any retries that failed - GoogleAPIError: raised if 
the request encounters an unrecoverable error """ - (operation_timeout, attempt_timeout) = _helpers._get_timeouts( + operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) attempt_timeout_gen = _helpers._attempt_timeout_generator( @@ -990,7 +990,7 @@ def mutate_row( safely retried. - ValueError if invalid arguments are provided """ - (operation_timeout, attempt_timeout) = _helpers._get_timeouts( + operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) if not mutations: @@ -1064,7 +1064,7 @@ def bulk_mutate_rows( Contains details about any failed entries in .exceptions - ValueError if invalid arguments are provided """ - (operation_timeout, attempt_timeout) = _helpers._get_timeouts( + operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) @@ -1122,7 +1122,7 @@ def check_and_mutate_row( Raises: - GoogleAPIError exceptions from grpc call """ - (operation_timeout, _) = _helpers._get_timeouts(operation_timeout, None, self) + operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) if true_case_mutations is not None and ( not isinstance(true_case_mutations, list) ): @@ -1181,7 +1181,7 @@ def read_modify_write_row( - GoogleAPIError exceptions from grpc call - ValueError if invalid arguments are provided """ - (operation_timeout, _) = _helpers._get_timeouts(operation_timeout, None, self) + operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) if operation_timeout <= 0: raise ValueError("operation_timeout must be greater than 0") if rules is not None and (not isinstance(rules, list)): diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index c1d27d4b5..c9bbd5b0e 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -97,6 +97,26 @@ def wait_sync(futures, timeout=None): return set(), set() return concurrent.futures.wait(futures, timeout=timeout) + @staticmethod + async def condition_wait(condition, timeout=None): + """ + abstraction over asyncio.Condition.wait + + returns False if the timeout is reached before the condition is set, otherwise True + """ + try: + await asyncio.wait_for(condition.wait(), timeout=timeout) + return True + except asyncio.TimeoutError: + return False + + @staticmethod + def condition_wait_sync(condition, timeout=None): + """ + returns False if the timeout is reached before the condition is set, otherwise True + """ + return condition.wait(timeout=timeout) + @staticmethod def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): """ @@ -119,3 +139,14 @@ def create_task_sync(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwar if not sync_executor: raise ValueError("sync_executor is required for sync version") return sync_executor.submit(fn, *fn_args, **fn_kwargs) + + @staticmethod + async def yield_to_event_loop(): + """ + Call asyncio.sleep(0) to yield to allow other tasks to run + """ + await asyncio.sleep(0) + + @staticmethod + def yield_to_event_loop_sync(): + pass diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 81e46c847..da4b178cc 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -18,7 +18,6 @@ from 
__future__ import annotations from abc import ABC from collections import deque -from typing import Any from typing import Sequence import asyncio import atexit @@ -90,16 +89,16 @@ def __init__( - batch_retryable_errors: a list of errors that will be retried if encountered. Defaults to the Table's default_mutate_rows_retryable_errors. """ - (self._operation_timeout, self._attempt_timeout) = _get_timeouts( + self._operation_timeout, self._attempt_timeout = _get_timeouts( batch_operation_timeout, batch_attempt_timeout, table ) self._retryable_errors: list[type[Exception]] = _get_retryable_errors( batch_retryable_errors, table ) - self._closed: asyncio.Event = asyncio.Event() + self._closed = threading.Event() self._table = table self._staged_entries: list[RowMutationEntry] = [] - (self._staged_count, self._staged_bytes) = (0, 0) + self._staged_count, self._staged_bytes = (0, 0) self._flow_control = _FlowControlAsync( flow_control_max_mutation_count, flow_control_max_bytes ) @@ -109,8 +108,10 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - self._flush_timer = self._create_bg_task(self._timer_routine, flush_interval) - self._flush_jobs: set[asyncio.Future[None]] = set() + self._flush_timer = CrossSync.create_task_sync( + self._timer_routine, flush_interval, sync_executor=self._sync_executor + ) + self._flush_jobs: set[concurrent.futures.Future[None]] = set() self._entries_processed_since_last_raise: int = 0 self._exceptions_since_last_raise: int = 0 self._exception_list_limit: int = 10 @@ -119,6 +120,10 @@ def __init__( maxlen=self._exception_list_limit ) atexit.register(self._on_exit) + if not False: + self._sync_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) + else: + self._sync_executor = None def _timer_routine(self, interval: float | None) -> None: """ @@ -128,10 +133,7 @@ def _timer_routine(self, interval: float | None) -> None: if not interval or interval <= 0: return None while not self._closed.is_set(): - try: - asyncio.wait_for(self._closed.wait(), timeout=interval) - except asyncio.TimeoutError: - pass + CrossSync.condition_wait_sync(self._closed, timeout=interval) if not self._closed.is_set() and self._staged_entries: self._schedule_flush() @@ -161,14 +163,16 @@ def append(self, mutation_entry: RowMutationEntry): or self._staged_bytes >= self._flush_limit_bytes ): self._schedule_flush() - asyncio.sleep(0) + CrossSync.yield_to_event_loop_sync() - def _schedule_flush(self) -> asyncio.Future[None] | None: + def _schedule_flush(self) -> concurrent.futures.Future[None] | None: """Update the flush task to include the latest staged entries""" if self._staged_entries: - (entries, self._staged_entries) = (self._staged_entries, []) - (self._staged_count, self._staged_bytes) = (0, 0) - new_task = self._create_bg_task(self._flush_internal, entries) + entries, self._staged_entries = (self._staged_entries, []) + self._staged_count, self._staged_bytes = (0, 0) + new_task = CrossSync.create_task_sync( + self._flush_internal, entries, sync_executor=self._sync_executor + ) if not new_task.done(): self._flush_jobs.add(new_task) new_task.add_done_callback(self._flush_jobs.remove) @@ -182,9 +186,13 @@ def _flush_internal(self, new_entries: list[RowMutationEntry]): Args: - new_entries: list of RowMutationEntry objects to flush """ - in_process_requests: list[asyncio.Future[list[FailedMutationEntryError]]] = [] + in_process_requests: list[ + concurrent.futures.Future[list[FailedMutationEntryError]] + ] = [] for batch in 
self._flow_control.add_to_flow(new_entries): - batch_task = self._create_bg_task(self._execute_mutate_rows, batch) + batch_task = CrossSync.create_task_sync( + self._execute_mutate_rows, batch, sync_executor=self._sync_executor + ) in_process_requests.append(batch_task) found_exceptions = self._wait_for_batch_results(*in_process_requests) self._entries_processed_since_last_raise += len(new_entries) @@ -244,14 +252,14 @@ def _raise_exceptions(self): - MutationsExceptionGroup with all unreported exceptions """ if self._oldest_exceptions or self._newest_exceptions: - (oldest, self._oldest_exceptions) = (self._oldest_exceptions, []) + oldest, self._oldest_exceptions = (self._oldest_exceptions, []) newest = list(self._newest_exceptions) self._newest_exceptions.clear() - (entry_count, self._entries_processed_since_last_raise) = ( + entry_count, self._entries_processed_since_last_raise = ( self._entries_processed_since_last_raise, 0, ) - (exc_count, self._exceptions_since_last_raise) = ( + exc_count, self._exceptions_since_last_raise = ( self._exceptions_since_last_raise, 0, ) @@ -283,12 +291,16 @@ def close(self): self._closed.set() self._flush_timer.cancel() self._schedule_flush() - if self._flush_jobs: - asyncio.gather(*self._flush_jobs, return_exceptions=True) - try: - self._flush_timer - except asyncio.CancelledError: - pass + if False: + if self._flush_jobs: + asyncio.gather(*self._flush_jobs, return_exceptions=True) + try: + self._flush_timer + except asyncio.CancelledError: + pass + else: + with self._sync_executor: + self._sync_executor.shutdown(wait=True) atexit.unregister(self._on_exit) self._raise_exceptions() @@ -299,26 +311,10 @@ def _on_exit(self): f"MutationsBatcher for table {self._table.table_name} was not closed. {len(self._staged_entries)} Unflushed mutations will not be sent to the server." ) - @staticmethod - def _create_bg_task(func, *args, **kwargs) -> asyncio.Future[Any]: - """ - Create a new background task, and return a future - - This method wraps asyncio to make it easier to maintain subclasses - with different concurrency models. 
- - Args: - - func: function to execute in background task - - *args: positional arguments to pass to func - - **kwargs: keyword arguments to pass to func - Returns: - - Future object representing the background task - """ - return asyncio.create_task(func(*args, **kwargs)) - @staticmethod def _wait_for_batch_results( - *tasks: asyncio.Future[list[FailedMutationEntryError]] | asyncio.Future[None], + *tasks: concurrent.futures.Future[list[FailedMutationEntryError]] + | concurrent.futures.Future[None], ) -> list[Exception]: """ Takes in a list of futures representing _execute_mutate_rows tasks, @@ -334,15 +330,16 @@ def _wait_for_batch_results( """ if not tasks: return [] - all_results = asyncio.gather(*tasks, return_exceptions=True) - found_errors = [] - for result in all_results: - if isinstance(result, Exception): - found_errors.append(result) - elif isinstance(result, BaseException): - raise result - elif result: - for e in result: - e.index = None - found_errors.extend(result) - return found_errors + exceptions: list[Exception] = [] + for task in tasks: + if False: + task + try: + exc_list = task.result() + if exc_list: + for exc in exc_list: + exc.index = None + exceptions.extend(exc_list) + except Exception as e: + exceptions.append(e) + return exceptions diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 76619187f..d3f074e60 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -57,9 +57,11 @@ def __init__(self, *, name=None, cross_sync_replacements=None, text_replacements "Task": "concurrent.futures.Future", "Event": "threading.Event", "is_async": "False", - "gather_partials": "gather_partials_sync", - "wait": "wait_sync", - "create_task": "create_task_sync", + "gather_partials": "CrossSync.gather_partials_sync", + "wait": "CrossSync.wait_sync", + "condition_wait": "CrossSync.condition_wait_sync", + "create_task": "CrossSync.create_task_sync", + "yield_to_event_loop": "CrossSync.yield_to_event_loop_sync", } self.text_replacements = text_replacements or {} self.drop_methods = drop_methods or [] From 7b4614beaee96233e3a0c97c5179362920525ffd Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 30 May 2024 16:06:53 -0700 Subject: [PATCH 060/360] made it work with private classes --- .../bigtable/data/_async/_mutate_rows.py | 5 +- .../cloud/bigtable/data/_sync/_mutate_rows.py | 191 ++++++++++++++++++ .../cloud/bigtable/data/_sync/cross_sync.py | 2 + .../bigtable/data/_sync/mutations_batcher.py | 116 +++++++++++ sync_surface_generator.py | 16 +- 5 files changed, 324 insertions(+), 6 deletions(-) create mode 100644 google/cloud/bigtable/data/_sync/_mutate_rows.py diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index aed14d338..88170b75e 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -29,6 +29,8 @@ # mutate_rows requests are limited to this number of mutations from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + if TYPE_CHECKING: from google.cloud.bigtable_v2.services.bigtable.async_client import ( BigtableAsyncClient, @@ -47,6 +49,7 @@ class _EntryWithProto: proto: types_pb.MutateRowsRequest.Entry +@CrossSync.sync_output("google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation") class _MutateRowsOperationAsync: """ MutateRowsOperation manages the logic of sending a set of row mutations, @@ -103,7 
+106,7 @@ def __init__( sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) # Note: _operation could be a raw coroutine, but using a lambda # wrapper helps unify with sync code - self._operation = lambda: retries.retry_target_async( + self._operation = lambda: CrossSync.retry_target( self._run_attempt, self.is_retryable, sleep_generator, diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py new file mode 100644 index 000000000..7d8e20cdd --- /dev/null +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -0,0 +1,191 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by sync_surface_generator.py. Do not edit. + + +from __future__ import annotations +from abc import ABC +from typing import Sequence +import functools + +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto +from google.cloud.bigtable.data._async.client import TableAsync +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data._helpers import _make_metadata +from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data.mutations import RowMutationEntry +from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient + + +@CrossSync.sync_output( + "google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation" +) +class _MutateRowsOperation(ABC): + """ + MutateRowsOperation manages the logic of sending a set of row mutations, + and retrying on failed entries. It manages this using the _run_attempt + function, which attempts to mutate all outstanding entries, and raises + _MutateRowsIncomplete if any retryable errors are encountered. + + Errors are exposed as a MutationsExceptionGroup, which contains a list of + exceptions organized by the related failed mutation entries. + """ + + def __init__( + self, + gapic_client: "BigtableAsyncClient", + table: "TableAsync", + mutation_entries: list["RowMutationEntry"], + operation_timeout: float, + attempt_timeout: float | None, + retryable_exceptions: Sequence[type[Exception]] = (), + ): + """ + Args: + - gapic_client: the client to use for the mutate_rows call + - table: the table associated with the request + - mutation_entries: a list of RowMutationEntry objects to send to the server + - operation_timeout: the timeout to use for the entire operation, in seconds. + - attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds. + If not specified, the request will run until operation_timeout is reached. 
+ """ + total_mutations = sum((len(entry.mutations) for entry in mutation_entries)) + if total_mutations > _MUTATE_ROWS_REQUEST_MUTATION_LIMIT: + raise ValueError( + f"mutate_rows requests can contain at most {_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across all entries. Found {total_mutations}." + ) + metadata = _make_metadata(table.table_name, table.app_profile_id) + self._gapic_fn = functools.partial( + gapic_client.mutate_rows, + table_name=table.table_name, + app_profile_id=table.app_profile_id, + metadata=metadata, + retry=None, + ) + self.is_retryable = retries.if_exception_type( + *retryable_exceptions, bt_exceptions._MutateRowsIncomplete + ) + sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + self._operation = lambda: retries.retry_target( + self._run_attempt, + self.is_retryable, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) + self.timeout_generator = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries] + self.remaining_indices = list(range(len(self.mutations))) + self.errors: dict[int, list[Exception]] = {} + + def start(self): + """ + Start the operation, and run until completion + + Raises: + - MutationsExceptionGroup: if any mutations failed + """ + try: + self._operation() + except Exception as exc: + incomplete_indices = self.remaining_indices.copy() + for idx in incomplete_indices: + self._handle_entry_error(idx, exc) + finally: + all_errors: list[Exception] = [] + for idx, exc_list in self.errors.items(): + if len(exc_list) == 0: + raise core_exceptions.ClientError( + f"Mutation {idx} failed with no associated errors" + ) + elif len(exc_list) == 1: + cause_exc = exc_list[0] + else: + cause_exc = bt_exceptions.RetryExceptionGroup(exc_list) + entry = self.mutations[idx].entry + all_errors.append( + bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc) + ) + if all_errors: + raise bt_exceptions.MutationsExceptionGroup( + all_errors, len(self.mutations) + ) + + def _run_attempt(self): + """ + Run a single attempt of the mutate_rows rpc. 
+ + Raises: + - _MutateRowsIncomplete: if there are failed mutations eligible for + retry after the attempt is complete + - GoogleAPICallError: if the gapic rpc fails + """ + request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] + active_request_indices = { + req_idx: orig_idx for req_idx, orig_idx in enumerate(self.remaining_indices) + } + self.remaining_indices = [] + if not request_entries: + return + try: + result_generator = self._gapic_fn( + timeout=next(self.timeout_generator), + entries=request_entries, + retry=None, + ) + for result_list in result_generator: + for result in result_list.entries: + orig_idx = active_request_indices[result.index] + entry_error = core_exceptions.from_grpc_status( + result.status.code, + result.status.message, + details=result.status.details, + ) + if result.status.code != 0: + self._handle_entry_error(orig_idx, entry_error) + elif orig_idx in self.errors: + del self.errors[orig_idx] + del active_request_indices[result.index] + except Exception as exc: + for idx in active_request_indices.values(): + self._handle_entry_error(idx, exc) + raise + if self.remaining_indices: + raise bt_exceptions._MutateRowsIncomplete + + def _handle_entry_error(self, idx: int, exc: Exception): + """ + Add an exception to the list of exceptions for a given mutation index, + and add the index to the list of remaining indices if the exception is + retryable. + + Args: + - idx: the index of the mutation that failed + - exc: the exception to add to the list + """ + entry = self.mutations[idx].entry + self.errors.setdefault(idx, []).append(exc) + if ( + entry.is_idempotent() + and self.is_retryable(exc) + and (idx not in self.remaining_indices) + ): + self.remaining_indices.append(idx) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index c9bbd5b0e..5f9b60483 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -16,6 +16,7 @@ import asyncio import sys import concurrent.futures +import google.api_core.retry as retries class CrossSync: @@ -29,6 +30,7 @@ class CrossSync: Future = asyncio.Future Task = asyncio.Task Event = asyncio.Event + retry_target = retries.retry_target_async @classmethod def sync_output(cls, sync_path): diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index da4b178cc..db7c1953a 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -34,6 +34,7 @@ from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.mutations import Mutation from google.cloud.bigtable.data.mutations import RowMutationEntry +from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT @CrossSync.sync_output( @@ -343,3 +344,118 @@ def _wait_for_batch_results( except Exception as e: exceptions.append(e) return exceptions + + +@CrossSync.sync_output( + "google.cloud.bigtable.data._sync.mutations_batcher._FlowControl" +) +class _FlowControl(ABC): + """ + Manages flow control for batched mutations. Mutations are registered against + the FlowControl object before being sent, which will block if size or count + limits have reached capacity. As mutations completed, they are removed from + the FlowControl object, which will notify any blocked requests that there + is additional capacity. + + Flow limits are not hard limits. 
If a single mutation exceeds the configured + limits, it will be allowed as a single batch when the capacity is available. + """ + + def __init__(self, max_mutation_count: int, max_mutation_bytes: int): + """ + Args: + - max_mutation_count: maximum number of mutations to send in a single rpc. + This corresponds to individual mutations in a single RowMutationEntry. + - max_mutation_bytes: maximum number of bytes to send in a single rpc. + """ + self._max_mutation_count = max_mutation_count + self._max_mutation_bytes = max_mutation_bytes + if self._max_mutation_count < 1: + raise ValueError("max_mutation_count must be greater than 0") + if self._max_mutation_bytes < 1: + raise ValueError("max_mutation_bytes must be greater than 0") + self._capacity_condition = threading.Condition() + self._in_flight_mutation_count = 0 + self._in_flight_mutation_bytes = 0 + + def _has_capacity(self, additional_count: int, additional_size: int) -> bool: + """ + Checks if there is capacity to send a new entry with the given size and count + + FlowControl limits are not hard limits. If a single mutation exceeds + the configured flow limits, it will be sent in a single batch when + previous batches have completed. + + Args: + - additional_count: number of mutations in the pending entry + - additional_size: size of the pending entry + Returns: + - True if there is capacity to send the pending entry, False otherwise + """ + acceptable_size = max(self._max_mutation_bytes, additional_size) + acceptable_count = max(self._max_mutation_count, additional_count) + new_size = self._in_flight_mutation_bytes + additional_size + new_count = self._in_flight_mutation_count + additional_count + return new_size <= acceptable_size and new_count <= acceptable_count + + def remove_from_flow( + self, mutations: RowMutationEntry | list[RowMutationEntry] + ) -> None: + """ + Removes mutations from flow control. This method should be called once + for each mutation that was sent to add_to_flow, after the corresponding + operation is complete. + + Args: + - mutations: mutation or list of mutations to remove from flow control + """ + if not isinstance(mutations, list): + mutations = [mutations] + total_count = sum((len(entry.mutations) for entry in mutations)) + total_size = sum((entry.size() for entry in mutations)) + self._in_flight_mutation_count -= total_count + self._in_flight_mutation_bytes -= total_size + with self._capacity_condition: + self._capacity_condition.notify_all() + + def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): + """ + Generator function that registers mutations with flow control. As mutations + are accepted into the flow control, they are yielded back to the caller, + to be sent in a batch. If the flow control is at capacity, the generator + will block until there is capacity available. + + Args: + - mutations: list mutations to break up into batches + Yields: + - list of mutations that have reserved space in the flow control. + Each batch contains at least one mutation. 
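An illustrative sketch (not part of the generated class) of the capacity gate this generator relies on: threading.Condition.wait_for blocks the producer until in-flight counts drop, and an oversized request is still admitted once nothing else is in flight.

import threading

class CapacityGate:
    """Toy stand-in for the counting logic in _FlowControl."""

    def __init__(self, max_units: int):
        self._max = max_units
        self._in_flight = 0
        self._cond = threading.Condition()

    def reserve(self, units: int) -> None:
        # block until the reservation fits; an oversized request passes when the gate is empty
        with self._cond:
            self._cond.wait_for(
                lambda: self._in_flight == 0 or self._in_flight + units <= self._max
            )
            self._in_flight += units

    def release(self, units: int) -> None:
        with self._cond:
            self._in_flight -= units
            self._cond.notify_all()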
+ """ + if not isinstance(mutations, list): + mutations = [mutations] + start_idx = 0 + end_idx = 0 + while end_idx < len(mutations): + start_idx = end_idx + batch_mutation_count = 0 + with self._capacity_condition: + while end_idx < len(mutations): + next_entry = mutations[end_idx] + next_size = next_entry.size() + next_count = len(next_entry.mutations) + if ( + self._has_capacity(next_count, next_size) + and batch_mutation_count + next_count + <= _MUTATE_ROWS_REQUEST_MUTATION_LIMIT + ): + end_idx += 1 + batch_mutation_count += next_count + self._in_flight_mutation_bytes += next_size + self._in_flight_mutation_count += next_count + elif start_idx != end_idx: + break + else: + self._capacity_condition.wait_for( + lambda: self._has_capacity(next_count, next_size) + ) + yield mutations[start_idx:end_idx] diff --git a/sync_surface_generator.py b/sync_surface_generator.py index d3f074e60..f2249648b 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -56,6 +56,7 @@ def __init__(self, *, name=None, cross_sync_replacements=None, text_replacements "Future": "concurrent.futures.Future", "Task": "concurrent.futures.Future", "Event": "threading.Event", + "retry_target": "retries.retry_target", "is_async": "False", "gather_partials": "CrossSync.gather_partials_sync", "wait": "CrossSync.wait_sync", @@ -406,10 +407,15 @@ def transform_from_config(config_dict: dict): # with open(save_path, "w") as f: # f.write(code) # find all classes in the library - import google.cloud.bigtable.data._async as data_lib - lib_classes = inspect.getmembers(data_lib, inspect.isclass) - # keep only those with CrossSync annotation - enabled_classes = [c[1] for c in lib_classes if hasattr(c[1], "cross_sync_enabled")] + lib_root = "google/cloud/bigtable/data/_async" + lib_files = [f"{lib_root}/{f}" for f in os.listdir(lib_root) if f.endswith(".py")] + enabled_classes = [] + for file in lib_files: + file_module = file.replace("/", ".")[:-3] + for cls_name, cls in inspect.getmembers(importlib.import_module(file_module), inspect.isclass): + # keep only those with CrossSync annotation + if hasattr(cls, "cross_sync_enabled") and not cls in enabled_classes: + enabled_classes.append(cls) # bucket classes by output location all_paths = {c.cross_sync_file_path for c in enabled_classes} class_map = {loc: [c for c in enabled_classes if c.cross_sync_file_path == loc] for loc in all_paths} @@ -468,7 +474,7 @@ def transform_from_config(config_dict: dict): full_code = f"{header}\n\n{import_str}\n\n{ast.unparse(combined_tree)}" full_code = autoflake.fix_code(full_code, remove_all_unused_imports=True) formatted_code = format_str(full_code, mode=FileMode()) - print(f"saving {async_class.cross_sync_class_name} to {output_file}...") + print(f"saving {[c.cross_sync_class_name for c in class_map[output_file]]} to {output_file}...") with open(output_file, "w") as f: f.write(formatted_code) From d69164a64ca884f81290a9aa41afd233e4ab95df Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 13:34:24 -0700 Subject: [PATCH 061/360] got broken unit tests working for batcher --- .../cloud/bigtable/data/_async/mutations_batcher.py | 11 +++++------ tests/unit/data/_async/test_mutations_batcher.py | 6 +++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index c7f0660ff..ed44b44c0 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ 
b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -238,6 +238,11 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) + # in sync mode, use a threadpool executor for background tasks + if not CrossSync.is_async: + self._sync_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) + else: + self._sync_executor = None self._flush_timer = CrossSync.create_task(self._timer_routine, flush_interval, sync_executor=self._sync_executor) self._flush_jobs: set[CrossSync.Future[None]] = set() # MutationExceptionGroup reports number of successful entries along with failures @@ -251,12 +256,6 @@ def __init__( ) # clean up on program exit atexit.register(self._on_exit) - # in sync mode, use a threadpool executor for background tasks - if not CrossSync.is_async: - self._sync_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) - else: - self._sync_executor = None - async def _timer_routine(self, interval: float | None) -> None: """ Triggers new flush tasks every `interval` seconds diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 072bc7545..db6dbd725 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -1215,14 +1215,14 @@ async def test_customizable_retryable_errors( down to the gapic layer. """ retryn_fn = ( - "retry_target_async" + "google.cloud.bigtable.data._sync.cross_sync.CrossSync.retry_target" if "Async" in self._get_target_class().__name__ - else "retry_target" + else "google.api_core.retry.retry_target" ) with mock.patch.object( google.api_core.retry, "if_exception_type" ) as predicate_builder_mock: - with mock.patch.object(google.api_core.retry, retryn_fn) as retry_fn_mock: + with mock.patch(retryn_fn) as retry_fn_mock: table = None with mock.patch("asyncio.create_task"): table = TableAsync(mock.Mock(), "instance", "table") From 63fe1d72aa12bb6aa255b3d90d69520e880a8cd0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 15:03:17 -0700 Subject: [PATCH 062/360] index for replaced file names --- .../bigtable/data/_async/mutations_batcher.py | 8 +++++--- google/cloud/bigtable/data/_sync/cross_sync.py | 14 +++++++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index ed44b44c0..cda304756 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -188,7 +188,7 @@ class MutationsBatcherAsync: def __init__( self, - table: "TableAsync", + table: CrossSync[TableAsync], *, flush_interval: float | None = 5, flush_limit_mutation_count: int | None = 1000, @@ -229,7 +229,7 @@ def __init__( self._table = table self._staged_entries: list[RowMutationEntry] = [] self._staged_count, self._staged_bytes = 0, 0 - self._flow_control = _FlowControlAsync( + self._flow_control = CrossSync[_FlowControlAsync]( flow_control_max_mutation_count, flow_control_max_bytes ) self._flush_limit_bytes = flush_limit_bytes @@ -344,7 +344,7 @@ async def _execute_mutate_rows( FailedMutationEntryError objects will not contain index information """ try: - operation = _MutateRowsOperationAsync( + operation = CrossSync[_MutateRowsOperationAsync]( self._table.client._gapic_client, self._table, batch, @@ -405,10 +405,12 @@ def _raise_exceptions(self): entry_count=entry_count, ) + @CrossSync.rename_sync("__enter__") async def 
__aenter__(self): """For context manager API""" return self + @CrossSync.rename_sync("__exit__") async def __aexit__(self, exc_type, exc, tb): """For context manager API""" await self.close() diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 5f9b60483..0a3ce293f 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -18,8 +18,12 @@ import concurrent.futures import google.api_core.retry as retries +class _AsyncGetAttr(type): -class CrossSync: + def __getitem__(cls, item): + return item + +class CrossSync(metaclass=_AsyncGetAttr): SyncImports = False is_async = True @@ -31,11 +35,19 @@ class CrossSync: Task = asyncio.Task Event = asyncio.Event retry_target = retries.retry_target_async + generated_replacements = {} + + @staticmethod + def rename_sync(*args, **kwargs): + def decorator(func): + return func + return decorator @classmethod def sync_output(cls, sync_path): # return the async class unchanged def decorator(async_cls): + cls.generated_replacements[async_cls] = sync_path async_cls.cross_sync_enabled = True async_cls.cross_sync_import_path = sync_path async_cls.cross_sync_class_name = sync_path.rsplit('.', 1)[-1] From d40c016acaa5de8553fa306ce08dbcc4629e319b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 15:31:29 -0700 Subject: [PATCH 063/360] simplified surface generator --- .../cloud/bigtable/data/_sync/_mutate_rows.py | 6 +- google/cloud/bigtable/data/_sync/client.py | 10 +- .../cloud/bigtable/data/_sync/cross_sync.py | 114 ++++++++------ .../bigtable/data/_sync/mutations_batcher.py | 48 +++--- sync_surface_generator.py | 140 +++--------------- 5 files changed, 125 insertions(+), 193 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index 7d8e20cdd..19cb389b4 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -27,13 +27,13 @@ from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._sync.cross_sync import _CrossSync_Sync from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient -@CrossSync.sync_output( +@_CrossSync_Sync.sync_output( "google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation" ) class _MutateRowsOperation(ABC): @@ -82,7 +82,7 @@ def __init__( *retryable_exceptions, bt_exceptions._MutateRowsIncomplete ) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - self._operation = lambda: retries.retry_target( + self._operation = lambda: _CrossSync_Sync.retry_target( self._run_attempt, self.is_retryable, sleep_generator, diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 8fc504281..f6860063a 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -174,7 +174,7 @@ def __init__( @staticmethod def _client_version() -> str: """Helper function to return the client version string for this client""" - if False: + if 
_CrossSync_Sync.is_async: return f"{google.cloud.bigtable.__version__}-data-async" else: return f"{google.cloud.bigtable.__version__}-data" @@ -191,7 +191,7 @@ def _start_background_channel_refresh(self) -> None: and (not self._is_closed.is_set()) ): for channel_idx in range(self.transport.pool_size): - refresh_task = CrossSync.create_task_sync( + refresh_task = _CrossSync_Sync.create_task( self._manage_channel, channel_idx, sync_executor=self._executor, @@ -210,7 +210,7 @@ def close(self, timeout: float | None = None): self.transport.close() if self._executor: self._executor.shutdown(wait=False) - CrossSync.wait_sync(self._channel_refresh_tasks, timeout=timeout) + _CrossSync_Sync.wait(self._channel_refresh_tasks, timeout=timeout) def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None @@ -247,7 +247,7 @@ def _ping_and_warm_instances( ) for instance_name, table_name, app_profile_id in instance_list ] - result_list = CrossSync.gather_partials_sync( + result_list = _CrossSync_Sync.gather_partials( partial_list, return_exceptions=True, sync_executor=self._executor ) return [r or None for r in result_list] @@ -278,7 +278,7 @@ def _manage_channel( grace_period: time to allow previous channel to serve existing requests before closing, in seconds """ - sleep_fn = asyncio.sleep if False else self._is_closed.wait + sleep_fn = asyncio.sleep if _CrossSync_Sync.is_async else self._is_closed.wait first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 0a3ce293f..cf4422998 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -17,6 +17,10 @@ import sys import concurrent.futures import google.api_core.retry as retries +import time +import threading +import queue + class _AsyncGetAttr(type): @@ -73,26 +77,6 @@ async def gather_partials(partial_list, return_exceptions=False, sync_executor=N awaitable_list = [partial() for partial in partial_list] return await asyncio.gather(*awaitable_list, return_exceptions=return_exceptions) - @staticmethod - def gather_partials_sync(partial_list, return_exceptions=False, sync_executor=None): - if not partial_list: - return [] - if not sync_executor: - raise ValueError("sync_executor is required for sync version") - futures_list = [ - sync_executor.submit(partial) for partial in partial_list - ] - results_list = [] - for future in futures_list: - if future.exception(): - if return_exceptions: - results_list.append(future.exception()) - else: - raise future.exception() - else: - results_list.append(future.result()) - return results_list - @staticmethod async def wait(futures, timeout=None): """ @@ -102,15 +86,6 @@ async def wait(futures, timeout=None): return set(), set() return await asyncio.wait(futures, timeout=timeout) - @staticmethod - def wait_sync(futures, timeout=None): - """ - abstraction over asyncio.wait - """ - if not futures: - return set(), set() - return concurrent.futures.wait(futures, timeout=timeout) - @staticmethod async def condition_wait(condition, timeout=None): """ @@ -124,13 +99,6 @@ async def condition_wait(condition, timeout=None): except asyncio.TimeoutError: return False - @staticmethod - def condition_wait_sync(condition, timeout=None): - """ - returns False if the timeout is reached before the condition is set, otherwise True - """ - return 
condition.wait(timeout=timeout) - @staticmethod def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): """ @@ -144,23 +112,81 @@ def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): return task @staticmethod - def create_task_sync(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): + async def yield_to_event_loop(): """ - abstraction over asyncio.create_task. Sync version implemented with threadpool executor + Call asyncio.sleep(0) to yield to allow other tasks to run + """ + await asyncio.sleep(0) - sync_executor: ThreadPoolExecutor to use for sync operations. Ignored in async version +class _SyncGetAttr(type): + + def __getitem__(cls, item): + breakpoint() + return CrossSync.generated_replacements[item] + + +class _CrossSync_Sync(metaclass=_SyncGetAttr): + + is_async = False + + sleep = time.sleep + Queue = queue.Queue + Condition = threading.Condition + Future = concurrent.futures.Future + Task = concurrent.futures.Future + Event = threading.Event + retry_target = retries.retry_target + generated_replacements = {} + + @staticmethod + def wait(futures, timeout=None): + """ + abstraction over asyncio.wait + """ + if not futures: + return set(), set() + return concurrent.futures.wait(futures, timeout=timeout) + + @staticmethod + def condition_wait(condition, timeout=None): """ + returns False if the timeout is reached before the condition is set, otherwise True + """ + return condition.wait(timeout=timeout) + + @staticmethod + def gather_partials(partial_list, return_exceptions=False, sync_executor=None): + if not partial_list: + return [] if not sync_executor: raise ValueError("sync_executor is required for sync version") - return sync_executor.submit(fn, *fn_args, **fn_kwargs) + futures_list = [ + sync_executor.submit(partial) for partial in partial_list + ] + results_list = [] + for future in futures_list: + if future.exception(): + if return_exceptions: + results_list.append(future.exception()) + else: + raise future.exception() + else: + results_list.append(future.result()) + return results_list + + @staticmethod - async def yield_to_event_loop(): + def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): """ - Call asyncio.sleep(0) to yield to allow other tasks to run + abstraction over asyncio.create_task. Sync version implemented with threadpool executor + + sync_executor: ThreadPoolExecutor to use for sync operations. 
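As a hedged aside (hypothetical class names, not part of the patch), the core of this pattern is a pair of classes that expose identical attribute names, one resolving to asyncio primitives and the other to threading and concurrent.futures equivalents, so shared code can be written once against the shim:

import asyncio
import concurrent.futures
import queue
import threading
import time

class CrossSyncAsyncSketch:
    is_async = True
    Event = asyncio.Event
    Condition = asyncio.Condition
    Queue = asyncio.Queue
    Future = asyncio.Future
    sleep = staticmethod(asyncio.sleep)

class CrossSyncSyncSketch:
    is_async = False
    Event = threading.Event
    Condition = threading.Condition
    Queue = queue.Queue
    Future = concurrent.futures.Future
    sleep = staticmethod(time.sleep)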
Ignored in async version """ - await asyncio.sleep(0) + if not sync_executor: + raise ValueError("sync_executor is required for sync version") + return sync_executor.submit(fn, *fn_args, **fn_kwargs) @staticmethod - def yield_to_event_loop_sync(): + def yield_to_event_loop(): pass diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index db7c1953a..a52f00c11 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -29,7 +29,7 @@ from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._sync.cross_sync import _CrossSync_Sync from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.mutations import Mutation @@ -37,7 +37,7 @@ from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT -@CrossSync.sync_output( +@_CrossSync_Sync.sync_output( "google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher" ) class MutationsBatcher(ABC): @@ -60,7 +60,7 @@ class MutationsBatcher(ABC): def __init__( self, - table: "TableAsync", + table: _CrossSync_Sync[TableAsync], *, flush_interval: float | None = 5, flush_limit_mutation_count: int | None = 1000, @@ -96,11 +96,11 @@ def __init__( self._retryable_errors: list[type[Exception]] = _get_retryable_errors( batch_retryable_errors, table ) - self._closed = threading.Event() + self._closed = _CrossSync_Sync.Event() self._table = table self._staged_entries: list[RowMutationEntry] = [] self._staged_count, self._staged_bytes = (0, 0) - self._flow_control = _FlowControlAsync( + self._flow_control = _CrossSync_Sync[_FlowControlAsync]( flow_control_max_mutation_count, flow_control_max_bytes ) self._flush_limit_bytes = flush_limit_bytes @@ -109,10 +109,14 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - self._flush_timer = CrossSync.create_task_sync( + if not _CrossSync_Sync.is_async: + self._sync_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) + else: + self._sync_executor = None + self._flush_timer = _CrossSync_Sync.create_task( self._timer_routine, flush_interval, sync_executor=self._sync_executor ) - self._flush_jobs: set[concurrent.futures.Future[None]] = set() + self._flush_jobs: set[_CrossSync_Sync.Future[None]] = set() self._entries_processed_since_last_raise: int = 0 self._exceptions_since_last_raise: int = 0 self._exception_list_limit: int = 10 @@ -121,10 +125,6 @@ def __init__( maxlen=self._exception_list_limit ) atexit.register(self._on_exit) - if not False: - self._sync_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) - else: - self._sync_executor = None def _timer_routine(self, interval: float | None) -> None: """ @@ -134,7 +134,7 @@ def _timer_routine(self, interval: float | None) -> None: if not interval or interval <= 0: return None while not self._closed.is_set(): - CrossSync.condition_wait_sync(self._closed, timeout=interval) + _CrossSync_Sync.condition_wait(self._closed, timeout=interval) if not self._closed.is_set() and self._staged_entries: self._schedule_flush() @@ -164,14 +164,14 @@ def append(self, mutation_entry: RowMutationEntry): or 
self._staged_bytes >= self._flush_limit_bytes ): self._schedule_flush() - CrossSync.yield_to_event_loop_sync() + _CrossSync_Sync.yield_to_event_loop() - def _schedule_flush(self) -> concurrent.futures.Future[None] | None: + def _schedule_flush(self) -> _CrossSync_Sync.Future[None] | None: """Update the flush task to include the latest staged entries""" if self._staged_entries: entries, self._staged_entries = (self._staged_entries, []) self._staged_count, self._staged_bytes = (0, 0) - new_task = CrossSync.create_task_sync( + new_task = _CrossSync_Sync.create_task( self._flush_internal, entries, sync_executor=self._sync_executor ) if not new_task.done(): @@ -188,10 +188,10 @@ def _flush_internal(self, new_entries: list[RowMutationEntry]): - new_entries: list of RowMutationEntry objects to flush """ in_process_requests: list[ - concurrent.futures.Future[list[FailedMutationEntryError]] + _CrossSync_Sync.Future[list[FailedMutationEntryError]] ] = [] for batch in self._flow_control.add_to_flow(new_entries): - batch_task = CrossSync.create_task_sync( + batch_task = _CrossSync_Sync.create_task( self._execute_mutate_rows, batch, sync_executor=self._sync_executor ) in_process_requests.append(batch_task) @@ -214,7 +214,7 @@ def _execute_mutate_rows( FailedMutationEntryError objects will not contain index information """ try: - operation = _MutateRowsOperationAsync( + operation = _CrossSync_Sync[_MutateRowsOperationAsync]( self._table.client._gapic_client, self._table, batch, @@ -292,7 +292,7 @@ def close(self): self._closed.set() self._flush_timer.cancel() self._schedule_flush() - if False: + if _CrossSync_Sync.is_async: if self._flush_jobs: asyncio.gather(*self._flush_jobs, return_exceptions=True) try: @@ -314,8 +314,8 @@ def _on_exit(self): @staticmethod def _wait_for_batch_results( - *tasks: concurrent.futures.Future[list[FailedMutationEntryError]] - | concurrent.futures.Future[None], + *tasks: _CrossSync_Sync.Future[list[FailedMutationEntryError]] + | _CrossSync_Sync.Future[None], ) -> list[Exception]: """ Takes in a list of futures representing _execute_mutate_rows tasks, @@ -333,7 +333,7 @@ def _wait_for_batch_results( return [] exceptions: list[Exception] = [] for task in tasks: - if False: + if _CrossSync_Sync.is_async: task try: exc_list = task.result() @@ -346,7 +346,7 @@ def _wait_for_batch_results( return exceptions -@CrossSync.sync_output( +@_CrossSync_Sync.sync_output( "google.cloud.bigtable.data._sync.mutations_batcher._FlowControl" ) class _FlowControl(ABC): @@ -374,7 +374,7 @@ def __init__(self, max_mutation_count: int, max_mutation_bytes: int): raise ValueError("max_mutation_count must be greater than 0") if self._max_mutation_bytes < 1: raise ValueError("max_mutation_bytes must be greater than 0") - self._capacity_condition = threading.Condition() + self._capacity_condition = _CrossSync_Sync.Condition() self._in_flight_mutation_count = 0 self._in_flight_mutation_bytes = 0 diff --git a/sync_surface_generator.py b/sync_surface_generator.py index f2249648b..b696c8570 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -37,11 +37,11 @@ class AsyncToSyncTransformer(ast.NodeTransformer): outside of this autogeneration system are always applied """ - def __init__(self, *, name=None, cross_sync_replacements=None, text_replacements=None, drop_methods=None, pass_methods=None, error_methods=None, replace_methods=None): + def __init__(self, *, name=None, asyncio_replacements=None, text_replacements=None, drop_methods=None, pass_methods=None, error_methods=None, 
replace_methods=None): """ Args: - name: the name of the class being processed. Just used in exceptions - - cross_sync_replacements: CrossSync functionality to replace + - asyncio_replacements: asyncio functionality to replace - text_replacements: dict of text to replace directly in the source code and docstrings - drop_methods: list of method names to drop from the class - pass_methods: list of method names to replace with "pass" in the class @@ -49,21 +49,7 @@ def __init__(self, *, name=None, cross_sync_replacements=None, text_replacements - replace_methods: dict of method names to replace with custom code """ self.name = name - self.cross_sync_replacements = cross_sync_replacements or { - "sleep": "time.sleep", - "Queue": "queue.Queue", - "Condition": "threading.Condition", - "Future": "concurrent.futures.Future", - "Task": "concurrent.futures.Future", - "Event": "threading.Event", - "retry_target": "retries.retry_target", - "is_async": "False", - "gather_partials": "CrossSync.gather_partials_sync", - "wait": "CrossSync.wait_sync", - "condition_wait": "CrossSync.condition_wait_sync", - "create_task": "CrossSync.create_task_sync", - "yield_to_event_loop": "CrossSync.yield_to_event_loop_sync", - } + self.asyncio_replacements = asyncio_replacements or {} self.text_replacements = text_replacements or {} self.drop_methods = drop_methods or [] self.pass_methods = pass_methods or [] @@ -115,27 +101,23 @@ def visit_AsyncFunctionDef(self, node): if len(parsed.body) > 0: new_body.append(parsed.body[0]) node.body = new_body - # else: + else: # check if the function contains non-replaced usage of asyncio - # func_ast = ast.parse(ast.unparse(node)) - # for n in ast.walk(func_ast): - # if isinstance(n, ast.Call) \ - # and isinstance(n.func, ast.Attribute) \ - # and isinstance(n.func.value, ast.Name) \ - # and n.func.value.id == "CrossSync" \ - # and n.func.attr not in self.cross_sync_replacements: - # path_str = f"{self.name}.{node.name}" if self.name else node.name - # breakpoint() - # raise RuntimeError( - # f"{path_str} contains unhandled asyncio calls: {n.func.attr}. Add method to drop_methods, pass_methods, or error_methods to handle safely." - # ) + func_ast = ast.parse(ast.unparse(node)) + for n in ast.walk(func_ast): + if isinstance(n, ast.Call) \ + and isinstance(n.func, ast.Attribute) \ + and isinstance(n.func.value, ast.Name) \ + and n.func.value.id == "asyncio" \ + and n.func.attr not in self.asyncio_replacements: + path_str = f"{self.name}.{node.name}" if self.name else node.name + print(f"{path_str} contains unhandled asyncio calls: {n.func.attr}. 
Add method to drop_methods, pass_methods, or error_methods to handle safely.") # remove pytest.mark.asyncio decorator if hasattr(node, "decorator_list"): # TODO: make generic is_asyncio_decorator = lambda d: all(x in ast.dump(d) for x in ["pytest", "mark", "asyncio"]) - is_cross_sync_decorator = lambda d: all(x in ast.dump(d) for x in ["CrossSync", "sync_output"]) node.decorator_list = [ - d for d in node.decorator_list if not is_asyncio_decorator(d) and not is_cross_sync_decorator(d) + d for d in node.decorator_list if not is_asyncio_decorator(d) ] # visit string type annotations @@ -171,10 +153,10 @@ def visit_Attribute(self, node): if ( isinstance(node.value, ast.Name) and isinstance(node.value.ctx, ast.Load) - and node.value.id == "CrossSync" - and node.attr in self.cross_sync_replacements + and node.value.id == "asyncio" + and node.attr in self.asyncio_replacements ): - replacement = self.cross_sync_replacements[node.attr] + replacement = self.asyncio_replacements[node.attr] return ast.copy_location(ast.parse(replacement, mode="eval").body, node) fixed = ast.copy_location( ast.Attribute( @@ -266,7 +248,7 @@ def _create_error_node(node, error_msg): def get_imports(self, filename): """ - Get the imports from a file, and do a find-and-replace against cross_sync_replacements + Get the imports from a file, and do a find-and-replace against asyncio_replacements """ imports = set() with open(filename, "r") as f: @@ -276,14 +258,14 @@ def get_imports(self, filename): for alias in node.names: if isinstance(node, ast.Import): # import statments - new_import = self.cross_sync_replacements.get(alias.name, alias.name) + new_import = self.asyncio_replacements.get(alias.name, alias.name) imports.add(ast.parse(f"import {new_import}").body[0]) else: # import from statements # break into individual components full_path = f"{node.module}.{alias.name}" - if full_path in self.cross_sync_replacements: - full_path = self.cross_sync_replacements[full_path] + if full_path in self.asyncio_replacements: + full_path = self.asyncio_replacements[full_path] module, name = full_path.rsplit(".", 1) # don't import from same file if module == ".": @@ -320,6 +302,7 @@ def transform_class(in_obj: Type, **kwargs): # find imports imports = transformer.get_imports(filename) imports.add(ast.parse("from abc import ABC").body[0]) + imports.add(ast.parse("from google.cloud.bigtable.data._sync.cross_sync import _CrossSync_Sync").body[0]) # add locals from file, in case they are needed if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): with open(filename, "r") as f: @@ -333,68 +316,6 @@ def transform_class(in_obj: Type, **kwargs): return ast_tree.body, imports -def transform_from_config(config_dict: dict): - # initialize new tree and import list - combined_tree = ast.parse("") - combined_imports = set() - # add new concrete classes to text_replacements - global_text_replacements = config_dict.get("text_replacements", {}) - for class_dict in config_dict["classes"]: - if "concrete_path" in class_dict: - class_name = class_dict["path"].rsplit(".", 1)[1] - global_text_replacements[class_name] = class_dict.pop("concrete_path") - # process each class - for class_dict in config_dict["classes"]: - # convert string class path into class object - module_path, class_name = class_dict.pop("path").rsplit(".", 1) - class_object = getattr(importlib.import_module(module_path), class_name) - # add globals to class_dict - class_dict["asyncio_replacements"] = {**config_dict.get("asyncio_replacements", {}), 
**class_dict.get("asyncio_replacements", {})} - class_dict["text_replacements"] = {**global_text_replacements, **class_dict.get("text_replacements", {})} - # add class-specific imports - for import_str in class_dict.pop("added_imports", []): - combined_imports.add(ast.parse(import_str).body[0]) - # transform class - tree_body, imports = transform_class(class_object, **class_dict) - # update combined data - combined_tree.body.extend(tree_body) - combined_imports.update(imports) - # add extra imports - for import_str in config_dict.get("added_imports", []): - combined_imports.add(ast.parse(import_str).body[0]) - # render tree as string of code - import_unique = list(set([ast.unparse(i) for i in combined_imports])) - import_unique.sort() - google, non_google = [], [] - for i in import_unique: - if "google" in i: - google.append(i) - else: - non_google.append(i) - import_str = "\n".join(non_google + [""] + google) - # append clean tree - header = """# Copyright 2024 Google LLC - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - # - # This file is automatically generated by sync_surface_generator.py. Do not edit. - """ - full_code = f"{header}\n\n{import_str}\n\n{ast.unparse(combined_tree)}" - full_code = autoflake.fix_code(full_code, remove_all_unused_imports=True) - formatted_code = format_str(full_code, mode=FileMode()) - return formatted_code - - if __name__ == "__main__": # for load_path in ["./google/cloud/bigtable/data/_sync/sync_gen.yaml", "./google/cloud/bigtable/data/_sync/unit_tests.yaml"]: # for load_path in ["./google/cloud/bigtable/data/_sync/sync_gen.yaml"]: @@ -425,22 +346,7 @@ def transform_from_config(config_dict: dict): combined_tree = ast.parse("") combined_imports = set() for async_class in class_map[output_file]: - class_dict = { - "text_replacements": { - "__anext__": "__next__", - "__aenter__": "__enter__", - "__aexit__": "__exit__", - "__aiter__": "__iter__", - "aclose": "close", - "AsyncIterable": "Iterable", - "AsyncIterator": "Iterator", - "StopAsyncIteration": "StopIteration", - "Awaitable": None, - "CrossSync.Event": "threading.Event", - }, - "autogen_sync_name": async_class.cross_sync_class_name, - } - tree_body, imports = transform_class(async_class, **class_dict) + tree_body, imports = transform_class(async_class, autogen_sync_name=async_class.cross_sync_class_name, text_replacements={"CrossSync": "_CrossSync_Sync"}) # update combined data combined_tree.body.extend(tree_body) combined_imports.update(imports) From 888cfc2a40b69a2df0694c9cd6d79687d9e84427 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 15:45:31 -0700 Subject: [PATCH 064/360] implemented rename_sync decorator --- sync_surface_generator.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sync_surface_generator.py b/sync_surface_generator.py index b696c8570..b2e4b76ab 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -115,6 +115,17 @@ def visit_AsyncFunctionDef(self, node): # remove pytest.mark.asyncio decorator if hasattr(node, 
"decorator_list"): # TODO: make generic + new_list = [] + for decorator in node.decorator_list: + # check for cross_sync decorator + if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name) and decorator.func.value.id == "CrossSync": + decorator_type = decorator.func.attr + if decorator_type == "rename_sync": + new_name = decorator.args[0].value + node.name = new_name + else: + new_list.append(decorator) + node.decorator_list = new_list is_asyncio_decorator = lambda d: all(x in ast.dump(d) for x in ["pytest", "mark", "asyncio"]) node.decorator_list = [ d for d in node.decorator_list if not is_asyncio_decorator(d) From 6d553054a74e7171a7af70a4088afd41a358a345 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 15:59:49 -0700 Subject: [PATCH 065/360] remove class decorator and abc super class --- .../cloud/bigtable/data/_sync/_mutate_rows.py | 6 +----- google/cloud/bigtable/data/_sync/client.py | 18 ++++++++++-------- .../bigtable/data/_sync/mutations_batcher.py | 11 ++--------- sync_surface_generator.py | 14 +++++++++----- 4 files changed, 22 insertions(+), 27 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index 19cb389b4..2f79e5ad8 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -16,7 +16,6 @@ from __future__ import annotations -from abc import ABC from typing import Sequence import functools @@ -33,10 +32,7 @@ from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient -@_CrossSync_Sync.sync_output( - "google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation" -) -class _MutateRowsOperation(ABC): +class _MutateRowsOperation: """ MutateRowsOperation manages the logic of sending a set of row mutations, and retrying on failed entries. 
It manages this using the _run_attempt diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index f6860063a..2545eb7a9 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -47,7 +47,7 @@ from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._sync.cross_sync import _CrossSync_Sync from google.cloud.bigtable.data.exceptions import FailedQueryShardError from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.cloud.bigtable.data.mutations import Mutation @@ -75,8 +75,7 @@ import google.auth.credentials -@CrossSync.sync_output("google.cloud.bigtable.data._sync.client.BigtableDataClient") -class BigtableDataClient(ClientWithProject, ABC): +class BigtableDataClient(ClientWithProject): def __init__( self, *, @@ -149,7 +148,11 @@ def __init__( self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() self._channel_refresh_tasks: list[asyncio.Task[None]] = [] - self._executor = concurrent.futures.ThreadPoolExecutor() if not False else None + self._executor = ( + concurrent.futures.ThreadPoolExecutor() + if not _CrossSync_Sync.is_async + else None + ) if self._emulator_host is not None: warnings.warn( "Connecting to Bigtable emulator at {}".format(self._emulator_host), @@ -405,8 +408,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._gapic_client.__exit__(exc_type, exc_val, exc_tb) -@CrossSync.sync_output("google.cloud.bigtable.data._sync.client.Table") -class Table(ABC): +class Table: """ Main Data API surface @@ -517,7 +519,7 @@ def __init__( ) self.default_retryable_errors = default_retryable_errors or () try: - self._register_instance_future = CrossSync.create_task_sync( + self._register_instance_future = _CrossSync_Sync.create_task( self.client._register_instance, self.instance_id, self, @@ -750,7 +752,7 @@ def read_rows_sharded( ) for query in batch ] - batch_result = CrossSync.gather_partials_sync( + batch_result = _CrossSync_Sync.gather_partials( batch_partial_list, return_exceptions=True, sync_executor=self.client._executor, diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index a52f00c11..6f892bc94 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -16,7 +16,6 @@ from __future__ import annotations -from abc import ABC from collections import deque from typing import Sequence import asyncio @@ -37,10 +36,7 @@ from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT -@_CrossSync_Sync.sync_output( - "google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher" -) -class MutationsBatcher(ABC): +class MutationsBatcher: """ Allows users to send batches using context manager API: @@ -346,10 +342,7 @@ def _wait_for_batch_results( return exceptions -@_CrossSync_Sync.sync_output( - "google.cloud.bigtable.data._sync.mutations_batcher._FlowControl" -) -class _FlowControl(ABC): +class _FlowControl: """ Manages flow control for batched mutations. 
Mutations are registered against the FlowControl object before being sent, which will block if size or count diff --git a/sync_surface_generator.py b/sync_surface_generator.py index b2e4b76ab..c30e7aa78 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -295,16 +295,20 @@ def transform_class(in_obj: Type, **kwargs): ast_tree = ast.parse(textwrap.dedent("".join(lines)), filename) new_name = None if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): + cls_node = ast_tree.body[0] + # remove cross_sync decorator + if hasattr(cls_node, "decorator_list"): + cls_node.decorator_list = [d for d in cls_node.decorator_list if not isinstance(d, ast.Call) or not isinstance(d.func, ast.Attribute) or not isinstance(d.func.value, ast.Name) or d.func.value.id != "CrossSync"] # update name - old_name = ast_tree.body[0].name + old_name = cls_node.name # set default name for new class if unset new_name = kwargs.pop("autogen_sync_name", f"{old_name}_SyncGen") - ast_tree.body[0].name = new_name + cls_node.name = new_name ast.increment_lineno(ast_tree, lineno - 1) # add ABC as base class - ast_tree.body[0].bases = ast_tree.body[0].bases + [ - ast.Name("ABC", ast.Load()), - ] + # cls_node.bases = ast_tree.body[0].bases + [ + # ast.Name("ABC", ast.Load()), + # ] # remove top-level imports if any. Add them back later ast_tree.body = [n for n in ast_tree.body if not isinstance(n, (ast.Import, ast.ImportFrom))] # transform From 605f8748f76f3e048ec5c71ac0bd70f2a4f06293 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 16:12:13 -0700 Subject: [PATCH 066/360] added support for class-based replacements --- google/cloud/bigtable/data/_async/client.py | 12 ++++++++++-- google/cloud/bigtable/data/_sync/client.py | 1 - google/cloud/bigtable/data/_sync/cross_sync.py | 4 +++- sync_surface_generator.py | 3 ++- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 752fa7047..5bafa524f 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -83,7 +83,10 @@ from google.cloud.bigtable.data._helpers import ShardedQuery -@CrossSync.sync_output("google.cloud.bigtable.data._sync.client.BigtableDataClient") +@CrossSync.sync_output( + "google.cloud.bigtable.data._sync.client.BigtableDataClient", + replace_symbols={"__aenter__": "__enter__", "__aexit__": "__exit__"} +) class BigtableDataClientAsync(ClientWithProject): def __init__( @@ -437,7 +440,10 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) -@CrossSync.sync_output("google.cloud.bigtable.data._sync.client.Table") +@CrossSync.sync_output( + "google.cloud.bigtable.data._sync.client.Table", + replace_symbols={"AsyncIterable": "Iterable"} +) class TableAsync: """ Main Data API surface @@ -1262,6 +1268,7 @@ async def close(self): self._register_instance_future.cancel() await self.client._remove_instance_registration(self.instance_id, self) + @CrossSync.rename_sync("__enter__") async def __aenter__(self): """ Implement async context manager protocol @@ -1273,6 +1280,7 @@ async def __aenter__(self): await self._register_instance_future return self + @CrossSync.rename_sync("__exit__") async def __aexit__(self, exc_type, exc_val, exc_tb): """ Implement async context manager protocol diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 2545eb7a9..fe3fd8d53 100644 
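The @CrossSync.rename_sync("__enter__") / @CrossSync.rename_sync("__exit__") markers used in the client diff above are consumed by the generator pass added in PATCH 064. A minimal sketch of how such a marker decorator could be defined (its body is not shown in this patch, so this is an assumption): it returns the coroutine unchanged, and only the AST pass acts on the recorded name.

class _CrossSyncSketch:
    @classmethod
    def rename_sync(cls, new_name):
        # Marker only: the async surface keeps its original method name; the
        # sync generator reads new_name out of the decorator in the AST and
        # renames the method in the generated sync class.
        def decorator(func):
            return func
        return decorator


class DemoAsync:
    @_CrossSyncSketch.rename_sync("__enter__")
    async def __aenter__(self):
        return self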
--- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -16,7 +16,6 @@ from __future__ import annotations -from abc import ABC from functools import partial from grpc import Channel from typing import Any diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index cf4422998..a554622b7 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -48,7 +48,8 @@ def decorator(func): return decorator @classmethod - def sync_output(cls, sync_path): + def sync_output(cls, sync_path, replace_symbols=None): + replace_symbols = replace_symbols or {} # return the async class unchanged def decorator(async_cls): cls.generated_replacements[async_cls] = sync_path @@ -56,6 +57,7 @@ def decorator(async_cls): async_cls.cross_sync_import_path = sync_path async_cls.cross_sync_class_name = sync_path.rsplit('.', 1)[-1] async_cls.cross_sync_file_path = "/".join(sync_path.split(".")[:-1]) + ".py" + async_cls.cross_sync_replace_symbols = replace_symbols return async_cls return decorator diff --git a/sync_surface_generator.py b/sync_surface_generator.py index c30e7aa78..7074d628c 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -361,7 +361,8 @@ def transform_class(in_obj: Type, **kwargs): combined_tree = ast.parse("") combined_imports = set() for async_class in class_map[output_file]: - tree_body, imports = transform_class(async_class, autogen_sync_name=async_class.cross_sync_class_name, text_replacements={"CrossSync": "_CrossSync_Sync"}) + text_replacements = {"CrossSync": "_CrossSync_Sync", **async_class.cross_sync_replace_symbols} + tree_body, imports = transform_class(async_class, autogen_sync_name=async_class.cross_sync_class_name, text_replacements=text_replacements) # update combined data combined_tree.body.extend(tree_body) combined_imports.update(imports) From e930e541e6d9be2f5d90172fd1f3545651f0fc4e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 16:18:55 -0700 Subject: [PATCH 067/360] replaced async classes --- .../bigtable/data/_async/_mutate_rows.py | 8 ++- google/cloud/bigtable/data/_async/client.py | 16 +++-- .../cloud/bigtable/data/_sync/_mutate_rows.py | 6 +- google/cloud/bigtable/data/_sync/client.py | 64 ++++++++----------- 4 files changed, 47 insertions(+), 47 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 88170b75e..c576735b2 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -49,7 +49,13 @@ class _EntryWithProto: proto: types_pb.MutateRowsRequest.Entry -@CrossSync.sync_output("google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation") +@CrossSync.sync_output( + "google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", + replace_symbols={ + "BigtableAsyncClient": "BigtableClient", + "TableAsync": "Table", + } +) class _MutateRowsOperationAsync: """ MutateRowsOperation manages the logic of sending a set of row mutations, diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 5bafa524f..0d324952e 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -85,7 +85,11 @@ @CrossSync.sync_output( "google.cloud.bigtable.data._sync.client.BigtableDataClient", - replace_symbols={"__aenter__": "__enter__", "__aexit__": 
"__exit__"} + replace_symbols={ + "__aenter__": "__enter__", "__aexit__": "__exit__", + "TableAsync": "Table", "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcIOTransport", + "BigtableAsyncClient": "BigtableClient", "AsyncPooledChannel": "PooledChannel" + } ) class BigtableDataClientAsync(ClientWithProject): @@ -442,7 +446,11 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): @CrossSync.sync_output( "google.cloud.bigtable.data._sync.client.Table", - replace_symbols={"AsyncIterable": "Iterable"} + replace_symbols={ + "AsyncIterable": "Iterable", + "MutationsBatcherAsync": "MutationsBatcher", + "BigtableDataClientAsync": "BigtableDataClient", + } ) class TableAsync: """ @@ -614,7 +622,7 @@ async def read_rows_stream( ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - row_merger = _ReadRowsOperationAsync( + row_merger = CrossSync[_ReadRowsOperationAsync]( query, self, operation_timeout=operation_timeout, @@ -1127,7 +1135,7 @@ async def bulk_mutate_rows( ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - operation = _MutateRowsOperationAsync( + operation = CrossSync[_MutateRowsOperationAsync]( self.client._gapic_client, self, mutation_entries, diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index 2f79e5ad8..0439fb143 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -22,14 +22,12 @@ from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto -from google.cloud.bigtable.data._async.client import TableAsync from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory from google.cloud.bigtable.data._sync.cross_sync import _CrossSync_Sync from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT -from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient class _MutateRowsOperation: @@ -45,8 +43,8 @@ class _MutateRowsOperation: def __init__( self, - gapic_client: "BigtableAsyncClient", - table: "TableAsync", + gapic_client: "BigtableClient", + table: "Table", mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index fe3fd8d53..0ad417812 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -39,9 +39,6 @@ from google.cloud.bigtable.data import _helpers from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync -from google.cloud.bigtable.data._async.client import BigtableDataClientAsync -from google.cloud.bigtable.data._async.client import TableAsync -from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery @@ -58,15 +55,8 @@ from google.cloud.bigtable.data.row_filters import RowFilter from 
google.cloud.bigtable.data.row_filters import RowFilterChain from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter -from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, -) -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel as AsyncPooledChannel, -) from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest from google.cloud.client import ClientWithProject from google.cloud.environment_vars import BIGTABLE_EMULATOR @@ -112,7 +102,7 @@ def __init__( - ValueError if pool_size is less than 1 """ transport_str = f"bt-{self._client_version()}-{pool_size}" - transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) + transport = PooledBigtableGrpcIOTransport.with_fixed_size(pool_size) BigtableClientMeta._transport_registry[transport_str] = transport client_info = DEFAULT_CLIENT_INFO client_info.client_library_version = self._client_version() @@ -133,7 +123,7 @@ def __init__( project=project, client_options=client_options, ) - self._gapic_client = BigtableAsyncClient( + self._gapic_client = BigtableClient( transport=transport_str, credentials=credentials, client_options=client_options, @@ -141,7 +131,7 @@ def __init__( ) self._is_closed = asyncio.Event() self.transport = cast( - PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport + PooledBigtableGrpcIOTransport, self._gapic_client.transport ) self._active_instances: Set[_helpers._WarmedInstanceKey] = set() self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} @@ -158,7 +148,7 @@ def __init__( RuntimeWarning, stacklevel=2, ) - self.transport._grpc_channel = AsyncPooledChannel( + self.transport._grpc_channel = PooledChannel( pool_size=pool_size, host=self._emulator_host, insecure=True ) self.transport._stubs = {} @@ -304,7 +294,7 @@ def _manage_channel( next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = next_refresh - (time.monotonic() - start_timestamp) - def _register_instance(self, instance_id: str, owner: TableAsync) -> None: + def _register_instance(self, instance_id: str, owner: Table) -> None: """ Registers an instance with the client, and warms the channel pool for the instance @@ -331,9 +321,7 @@ def _register_instance(self, instance_id: str, owner: TableAsync) -> None: else: self._start_background_channel_refresh() - def _remove_instance_registration( - self, instance_id: str, owner: TableAsync - ) -> bool: + def _remove_instance_registration(self, instance_id: str, owner: Table) -> bool: """ Removes an instance from the client's registered instances, to prevent warming new channels for the instance @@ -361,10 +349,10 @@ def _remove_instance_registration( except KeyError: return False - def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: + def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> Table: """ Returns a table instance for making data API requests. All arguments are passed - directly to the TableAsync constructor. + directly to the Table constructor. Args: instance_id: The Bigtable instance ID to associate with this client. 
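As a rough stand-alone illustration of the replace_symbols-driven renames visible in this generated file (TableAsync becoming Table, BigtableAsyncClient becoming BigtableClient, and so on), the sketch below applies a replacement map with an ast.NodeTransformer. It is a simplified stand-in, not the project's transformer, which also rewrites string annotations, docstrings, and imports.

import ast

class SymbolReplacer(ast.NodeTransformer):
    """Rename bare identifiers according to a replace_symbols-style map."""

    def __init__(self, replacements):
        self.replacements = replacements

    def visit_Name(self, node):
        node.id = self.replacements.get(node.id, node.id)
        return node

source = (
    "def get_table(self, *args, **kwargs) -> TableAsync:\n"
    "    return TableAsync(self, *args, **kwargs)"
)
tree = SymbolReplacer({"TableAsync": "Table"}).visit(ast.parse(source))
print(ast.unparse(tree))  # the annotation and constructor call now refer to Table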
@@ -396,7 +384,7 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs encountered during all other operations. Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) """ - return TableAsync(self, instance_id, table_id, *args, **kwargs) + return Table(self, instance_id, table_id, *args, **kwargs) def __enter__(self): self._start_background_channel_refresh() @@ -417,7 +405,7 @@ class Table: def __init__( self, - client: BigtableDataClientAsync, + client: BigtableDataClient, instance_id: str, table_id: str, app_profile_id: str | None = None, @@ -545,7 +533,7 @@ def read_rows_stream( Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -572,7 +560,7 @@ def read_rows_stream( operation_timeout, attempt_timeout, self ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - row_merger = _ReadRowsOperationAsync( + row_merger = _CrossSync_Sync[_ReadRowsOperationAsync]( query, self, operation_timeout=operation_timeout, @@ -598,7 +586,7 @@ def read_rows( Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -647,7 +635,7 @@ def read_row( Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -705,7 +693,7 @@ def read_rows_sharded( results = await table.read_rows_sharded(shard_queries) ``` - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -788,7 +776,7 @@ def row_exists( Return a boolean indicating whether the specified row exists in the table. uses the filters: chain(limit cells per row = 1, strip value) - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -844,7 +832,7 @@ def sample_row_keys( RowKeySamples is simply a type alias for list[tuple[bytes, int]]; a list of row_keys, along with offset positions in the table - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -907,14 +895,14 @@ def mutations_batcher( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - ) -> MutationsBatcherAsync: + ) -> MutationsBatcher: """ Returns a new mutations batcher instance. Can be used to iteratively add mutations that are flushed as a group, to avoid excess network calls - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. 
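A hypothetical usage sketch of the batcher factory documented here; the names follow this docstring and the MutationsBatcher diff earlier in the series (append, context-manager close), but the exact sync call signature is an assumption.

def write_rows(table, entries):
    # table is a (sync) Table instance; entries are RowMutationEntry objects
    with table.mutations_batcher(flush_interval=5) as batcher:
        for entry in entries:
            batcher.append(entry)  # staged entries are flushed in the background
    # leaving the context closes the batcher and flushes anything still staged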
Args: @@ -933,9 +921,9 @@ def mutations_batcher( - batch_retryable_errors: a list of errors that will be retried if encountered. Defaults to the Table's default_mutate_rows_retryable_errors. Returns: - - a MutationsBatcherAsync context manager that can batch requests + - a MutationsBatcher context manager that can batch requests """ - return MutationsBatcherAsync( + return MutationsBatcher( self, flush_interval=flush_interval, flush_limit_mutation_count=flush_limit_mutation_count, @@ -966,7 +954,7 @@ def mutate_row( Idempotent operations (i.e, all mutations have an explicit timestamp) will be retried on server failure. Non-idempotent operations will not. - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1043,7 +1031,7 @@ def bulk_mutate_rows( will be retried on failure. Non-idempotent will not, and will reported in a raised exception group - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1069,7 +1057,7 @@ def bulk_mutate_rows( operation_timeout, attempt_timeout, self ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - operation = _MutateRowsOperationAsync( + operation = _CrossSync_Sync[_MutateRowsOperationAsync]( self.client._gapic_client, self, mutation_entries, @@ -1093,7 +1081,7 @@ def check_and_mutate_row( Non-idempotent operation: will not be retried - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. Args: @@ -1164,7 +1152,7 @@ def read_modify_write_row( Non-idempotent operation: will not be retried - Warning: BigtableDataClientAsync is currently in preview, and is not + Warning: BigtableDataClient is currently in preview, and is not yet recommended for production use. 
Args: From 96f92b8b449de569fdfa75e70ba09dddae070352 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 16:33:09 -0700 Subject: [PATCH 068/360] added generation to ReadRowsOperation --- .../cloud/bigtable/data/_async/_read_rows.py | 3 + .../cloud/bigtable/data/_sync/_read_rows.py | 282 ++++++++++++++++++ 2 files changed, 285 insertions(+) create mode 100644 google/cloud/bigtable/data/_sync/_read_rows.py diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index b50cc9adc..9ab7003a4 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -36,6 +36,8 @@ from google.api_core import retry as retries from google.api_core.retry import exponential_sleep_generator +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + if TYPE_CHECKING: from google.cloud.bigtable.data._async.client import TableAsync @@ -45,6 +47,7 @@ def __init__(self, chunk): self.chunk = chunk +@CrossSync.sync_output("google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation") class _ReadRowsOperationAsync: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py new file mode 100644 index 000000000..7f478afe9 --- /dev/null +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -0,0 +1,282 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by sync_surface_generator.py. Do not edit. + + +from __future__ import annotations +from typing import AsyncIterable +from typing import Awaitable +from typing import Sequence + +from google.api_core import retry as retries +from google.api_core.retry import exponential_sleep_generator +from google.cloud.bigtable.data import _helpers +from google.cloud.bigtable.data._async._read_rows import _ResetRow +from google.cloud.bigtable.data._async.client import TableAsync +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.exceptions import _RowSetComplete +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.row import Cell +from google.cloud.bigtable.data.row import Row +from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB +from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB +from google.cloud.bigtable_v2.types import RowRange as RowRangePB +from google.cloud.bigtable_v2.types import RowSet as RowSetPB + + +class _ReadRowsOperation: + """ + ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream + into a stream of Row objects. + + ReadRowsOperation.merge_row_response_stream takes in a stream of ReadRowsResponse + and turns them into a stream of Row objects using an internal + StateMachine. 
+ + ReadRowsOperation(request, client) handles row merging logic end-to-end, including + performing retries on stream errors. + """ + + __slots__ = ( + "attempt_timeout_gen", + "operation_timeout", + "request", + "table", + "_predicate", + "_metadata", + "_last_yielded_row_key", + "_remaining_count", + ) + + def __init__( + self, + query: ReadRowsQuery, + table: "TableAsync", + operation_timeout: float, + attempt_timeout: float, + retryable_exceptions: Sequence[type[Exception]] = (), + ): + self.attempt_timeout_gen = _helpers._attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + self.operation_timeout = operation_timeout + if isinstance(query, dict): + self.request = ReadRowsRequestPB( + **query, + table_name=table.table_name, + app_profile_id=table.app_profile_id, + ) + else: + self.request = query._to_pb(table) + self.table = table + self._predicate = retries.if_exception_type(*retryable_exceptions) + self._metadata = _helpers._make_metadata(table.table_name, table.app_profile_id) + self._last_yielded_row_key: bytes | None = None + self._remaining_count: int | None = self.request.rows_limit or None + + def start_operation(self) -> AsyncIterable[Row]: + """Start the read_rows operation, retrying on retryable errors.""" + return retries.retry_target_stream_async( + self._read_rows_attempt, + self._predicate, + exponential_sleep_generator(0.01, 60, multiplier=2), + self.operation_timeout, + exception_factory=_helpers._retry_exception_factory, + ) + + def _read_rows_attempt(self) -> AsyncIterable[Row]: + """ + Attempt a single read_rows rpc call. + This function is intended to be wrapped by retry logic, + which will call this function until it succeeds or + a non-retryable error is raised. + """ + if self._last_yielded_row_key is not None: + try: + self.request.rows = self._revise_request_rowset( + row_set=self.request.rows, + last_seen_row_key=self._last_yielded_row_key, + ) + except _RowSetComplete: + return self.merge_rows(None) + if self._remaining_count is not None: + self.request.rows_limit = self._remaining_count + if self._remaining_count == 0: + return self.merge_rows(None) + gapic_stream = self.table.client._gapic_client.read_rows( + self.request, + timeout=next(self.attempt_timeout_gen), + metadata=self._metadata, + retry=None, + ) + chunked_stream = self.chunk_stream(gapic_stream) + return self.merge_rows(chunked_stream) + + def chunk_stream( + self, stream: Awaitable[AsyncIterable[ReadRowsResponsePB]] + ) -> AsyncIterable[ReadRowsResponsePB.CellChunk]: + """process chunks out of raw read_rows stream""" + for resp in stream: + resp = resp._pb + if resp.last_scanned_row_key: + if ( + self._last_yielded_row_key is not None + and resp.last_scanned_row_key <= self._last_yielded_row_key + ): + raise InvalidChunk("last scanned out of order") + self._last_yielded_row_key = resp.last_scanned_row_key + current_key = None + for c in resp.chunks: + if current_key is None: + current_key = c.row_key + if current_key is None: + raise InvalidChunk("first chunk is missing a row key") + elif ( + self._last_yielded_row_key + and current_key <= self._last_yielded_row_key + ): + raise InvalidChunk("row keys should be strictly increasing") + yield c + if c.reset_row: + current_key = None + elif c.commit_row: + self._last_yielded_row_key = current_key + if self._remaining_count is not None: + self._remaining_count -= 1 + if self._remaining_count < 0: + raise InvalidChunk("emit count exceeds row limit") + current_key = None + + @staticmethod + def merge_rows(chunks: 
AsyncIterable[ReadRowsResponsePB.CellChunk] | None): + """Merge chunks into rows""" + if chunks is None: + return + it = chunks.__aiter__() + while True: + try: + c = it.__anext__() + except StopAsyncIteration: + return + row_key = c.row_key + if not row_key: + raise InvalidChunk("first row chunk is missing key") + cells = [] + family: str | None = None + qualifier: bytes | None = None + try: + while True: + if c.reset_row: + raise _ResetRow(c) + k = c.row_key + f = c.family_name.value + q = c.qualifier.value if c.HasField("qualifier") else None + if k and k != row_key: + raise InvalidChunk("unexpected new row key") + if f: + family = f + if q is not None: + qualifier = q + else: + raise InvalidChunk("new family without qualifier") + elif family is None: + raise InvalidChunk("missing family") + elif q is not None: + if family is None: + raise InvalidChunk("new qualifier without family") + qualifier = q + elif qualifier is None: + raise InvalidChunk("missing qualifier") + ts = c.timestamp_micros + labels = c.labels if c.labels else [] + value = c.value + if c.value_size > 0: + buffer = [value] + while c.value_size > 0: + c = it.__anext__() + t = c.timestamp_micros + cl = c.labels + k = c.row_key + if ( + c.HasField("family_name") + and c.family_name.value != family + ): + raise InvalidChunk("family changed mid cell") + if ( + c.HasField("qualifier") + and c.qualifier.value != qualifier + ): + raise InvalidChunk("qualifier changed mid cell") + if t and t != ts: + raise InvalidChunk("timestamp changed mid cell") + if cl and cl != labels: + raise InvalidChunk("labels changed mid cell") + if k and k != row_key: + raise InvalidChunk("row key changed mid cell") + if c.reset_row: + raise _ResetRow(c) + buffer.append(c.value) + value = b"".join(buffer) + cells.append( + Cell(value, row_key, family, qualifier, ts, list(labels)) + ) + if c.commit_row: + yield Row(row_key, cells) + break + c = it.__anext__() + except _ResetRow as e: + c = e.chunk + if ( + c.row_key + or c.HasField("family_name") + or c.HasField("qualifier") + or c.timestamp_micros + or c.labels + or c.value + ): + raise InvalidChunk("reset row with data") + continue + except StopAsyncIteration: + raise InvalidChunk("premature end of stream") + + @staticmethod + def _revise_request_rowset(row_set: RowSetPB, last_seen_row_key: bytes) -> RowSetPB: + """ + Revise the rows in the request to avoid ones we've already processed. 
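A small worked example of the revision performed below, using plain Python values in place of the RowSetPB/RowRangePB protos: after yielding row key b"b", keys at or before it are dropped, and a range that still covers later keys is re-anchored to start exclusively at the last seen key.

last_seen = b"b"
row_keys = [b"a", b"b", b"c"]
adjusted_keys = [k for k in row_keys if k > last_seen]  # -> [b"c"]

start_key_closed, end_key_open = b"a", b"z"  # original range, covers the last seen key
adjusted_ranges = []
if end_key_open is None or end_key_open > last_seen:
    new_range = {"start_key_closed": start_key_closed, "end_key_open": end_key_open}
    if start_key_closed is None or start_key_closed <= last_seen:
        new_range = {"start_key_open": last_seen, "end_key_open": end_key_open}
    adjusted_ranges.append(new_range)
print(adjusted_keys, adjusted_ranges)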
+ + Args: + - row_set: the row set from the request + - last_seen_row_key: the last row key encountered + Raises: + - _RowSetComplete: if there are no rows left to process after the revision + """ + if row_set is None or (not row_set.row_ranges and row_set.row_keys is not None): + last_seen = last_seen_row_key + return RowSetPB(row_ranges=[RowRangePB(start_key_open=last_seen)]) + adjusted_keys: list[bytes] = [ + k for k in row_set.row_keys if k > last_seen_row_key + ] + adjusted_ranges: list[RowRangePB] = [] + for row_range in row_set.row_ranges: + end_key = row_range.end_key_closed or row_range.end_key_open or None + if end_key is None or end_key > last_seen_row_key: + new_range = RowRangePB(row_range) + start_key = row_range.start_key_closed or row_range.start_key_open + if start_key is None or start_key <= last_seen_row_key: + new_range.start_key_open = last_seen_row_key + adjusted_ranges.append(new_range) + if len(adjusted_keys) == 0 and len(adjusted_ranges) == 0: + raise _RowSetComplete() + return RowSetPB(row_keys=adjusted_keys, row_ranges=adjusted_ranges) From f37aa388535e78ace0f4483277d64eafdc0a0c72 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 16:36:41 -0700 Subject: [PATCH 069/360] set up symbol replacement for ReadRowsOperation --- .../cloud/bigtable/data/_async/_read_rows.py | 6 +++++- .../cloud/bigtable/data/_sync/_read_rows.py | 19 ++++++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 9ab7003a4..04fc4072c 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -47,7 +47,11 @@ def __init__(self, chunk): self.chunk = chunk -@CrossSync.sync_output("google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation") +@CrossSync.sync_output("google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", + replace_symbols={ + "AsyncIterable": "Iterable", "StopAsyncIteration": "StopIteration", "Awaitable": None, "TableAsync": "Table", + } +) class _ReadRowsOperationAsync: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index 7f478afe9..d6b84912c 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -16,15 +16,12 @@ from __future__ import annotations -from typing import AsyncIterable -from typing import Awaitable from typing import Sequence from google.api_core import retry as retries from google.api_core.retry import exponential_sleep_generator from google.cloud.bigtable.data import _helpers from google.cloud.bigtable.data._async._read_rows import _ResetRow -from google.cloud.bigtable.data._async.client import TableAsync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _RowSetComplete from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery @@ -63,7 +60,7 @@ class _ReadRowsOperation: def __init__( self, query: ReadRowsQuery, - table: "TableAsync", + table: "Table", operation_timeout: float, attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), @@ -86,7 +83,7 @@ def __init__( self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None - def start_operation(self) -> AsyncIterable[Row]: + def 
start_operation(self) -> Iterable[Row]: """Start the read_rows operation, retrying on retryable errors.""" return retries.retry_target_stream_async( self._read_rows_attempt, @@ -96,7 +93,7 @@ def start_operation(self) -> AsyncIterable[Row]: exception_factory=_helpers._retry_exception_factory, ) - def _read_rows_attempt(self) -> AsyncIterable[Row]: + def _read_rows_attempt(self) -> Iterable[Row]: """ Attempt a single read_rows rpc call. This function is intended to be wrapped by retry logic, @@ -125,8 +122,8 @@ def _read_rows_attempt(self) -> AsyncIterable[Row]: return self.merge_rows(chunked_stream) def chunk_stream( - self, stream: Awaitable[AsyncIterable[ReadRowsResponsePB]] - ) -> AsyncIterable[ReadRowsResponsePB.CellChunk]: + self, stream: Iterable[ReadRowsResponsePB] + ) -> Iterable[ReadRowsResponsePB.CellChunk]: """process chunks out of raw read_rows stream""" for resp in stream: resp = resp._pb @@ -160,7 +157,7 @@ def chunk_stream( current_key = None @staticmethod - def merge_rows(chunks: AsyncIterable[ReadRowsResponsePB.CellChunk] | None): + def merge_rows(chunks: Iterable[ReadRowsResponsePB.CellChunk] | None): """Merge chunks into rows""" if chunks is None: return @@ -168,7 +165,7 @@ def merge_rows(chunks: AsyncIterable[ReadRowsResponsePB.CellChunk] | None): while True: try: c = it.__anext__() - except StopAsyncIteration: + except StopIteration: return row_key = c.row_key if not row_key: @@ -248,7 +245,7 @@ def merge_rows(chunks: AsyncIterable[ReadRowsResponsePB.CellChunk] | None): ): raise InvalidChunk("reset row with data") continue - except StopAsyncIteration: + except StopIteration: raise InvalidChunk("premature end of stream") @staticmethod From 37549e9ee2b23b7c1882bd4a9383e711597cd404 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 16:38:41 -0700 Subject: [PATCH 070/360] use cross sync for retries --- google/cloud/bigtable/data/_async/_read_rows.py | 2 +- google/cloud/bigtable/data/_sync/_read_rows.py | 3 ++- google/cloud/bigtable/data/_sync/cross_sync.py | 2 ++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 04fc4072c..987f0102b 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -109,7 +109,7 @@ def start_operation(self) -> AsyncIterable[Row]: """ Start the read_rows operation, retrying on retryable errors. 
""" - return retries.retry_target_stream_async( + return CrossSync.retry_target_stream( self._read_rows_attempt, self._predicate, exponential_sleep_generator(0.01, 60, multiplier=2), diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index d6b84912c..76d7090fc 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -22,6 +22,7 @@ from google.api_core.retry import exponential_sleep_generator from google.cloud.bigtable.data import _helpers from google.cloud.bigtable.data._async._read_rows import _ResetRow +from google.cloud.bigtable.data._sync.cross_sync import _CrossSync_Sync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _RowSetComplete from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery @@ -85,7 +86,7 @@ def __init__( def start_operation(self) -> Iterable[Row]: """Start the read_rows operation, retrying on retryable errors.""" - return retries.retry_target_stream_async( + return _CrossSync_Sync.retry_target_stream( self._read_rows_attempt, self._predicate, exponential_sleep_generator(0.01, 60, multiplier=2), diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index a554622b7..05b2cd503 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -39,6 +39,7 @@ class CrossSync(metaclass=_AsyncGetAttr): Task = asyncio.Task Event = asyncio.Event retry_target = retries.retry_target_async + retry_target_stream = retries.retry_target_stream_async generated_replacements = {} @staticmethod @@ -138,6 +139,7 @@ class _CrossSync_Sync(metaclass=_SyncGetAttr): Task = concurrent.futures.Future Event = threading.Event retry_target = retries.retry_target + retry_target_stream = retries.retry_target_stream generated_replacements = {} @staticmethod From b2975fbf92e57c6daacd27d361a41f7bd84a694c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 16:44:31 -0700 Subject: [PATCH 071/360] refactored how sync implementation is handled --- .../cloud/bigtable/data/_sync/_mutate_rows.py | 4 +- .../cloud/bigtable/data/_sync/_read_rows.py | 4 +- google/cloud/bigtable/data/_sync/client.py | 24 ++-- .../cloud/bigtable/data/_sync/cross_sync.py | 132 +++++++++--------- .../bigtable/data/_sync/mutations_batcher.py | 38 ++--- sync_surface_generator.py | 5 +- 6 files changed, 104 insertions(+), 103 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index 0439fb143..fc0990e13 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -25,7 +25,7 @@ from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory -from google.cloud.bigtable.data._sync.cross_sync import _CrossSync_Sync +from google.cloud.bigtable.data._sync.cross_sync import CrossSync from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT @@ -76,7 +76,7 @@ def __init__( *retryable_exceptions, bt_exceptions._MutateRowsIncomplete ) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - self._operation = lambda: _CrossSync_Sync.retry_target( + 
self._operation = lambda: CrossSync._Sync_Impl.retry_target( self._run_attempt, self.is_retryable, sleep_generator, diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index 76d7090fc..a5ee1d110 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -22,7 +22,7 @@ from google.api_core.retry import exponential_sleep_generator from google.cloud.bigtable.data import _helpers from google.cloud.bigtable.data._async._read_rows import _ResetRow -from google.cloud.bigtable.data._sync.cross_sync import _CrossSync_Sync +from google.cloud.bigtable.data._sync.cross_sync import CrossSync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _RowSetComplete from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery @@ -86,7 +86,7 @@ def __init__( def start_operation(self) -> Iterable[Row]: """Start the read_rows operation, retrying on retryable errors.""" - return _CrossSync_Sync.retry_target_stream( + return CrossSync._Sync_Impl.retry_target_stream( self._read_rows_attempt, self._predicate, exponential_sleep_generator(0.01, 60, multiplier=2), diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 0ad417812..220eb130d 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -43,7 +43,7 @@ from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data._sync.cross_sync import _CrossSync_Sync +from google.cloud.bigtable.data._sync.cross_sync import CrossSync from google.cloud.bigtable.data.exceptions import FailedQueryShardError from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.cloud.bigtable.data.mutations import Mutation @@ -139,7 +139,7 @@ def __init__( self._channel_refresh_tasks: list[asyncio.Task[None]] = [] self._executor = ( concurrent.futures.ThreadPoolExecutor() - if not _CrossSync_Sync.is_async + if not CrossSync._Sync_Impl.is_async else None ) if self._emulator_host is not None: @@ -166,7 +166,7 @@ def __init__( @staticmethod def _client_version() -> str: """Helper function to return the client version string for this client""" - if _CrossSync_Sync.is_async: + if CrossSync._Sync_Impl.is_async: return f"{google.cloud.bigtable.__version__}-data-async" else: return f"{google.cloud.bigtable.__version__}-data" @@ -183,7 +183,7 @@ def _start_background_channel_refresh(self) -> None: and (not self._is_closed.is_set()) ): for channel_idx in range(self.transport.pool_size): - refresh_task = _CrossSync_Sync.create_task( + refresh_task = CrossSync._Sync_Impl.create_task( self._manage_channel, channel_idx, sync_executor=self._executor, @@ -202,7 +202,7 @@ def close(self, timeout: float | None = None): self.transport.close() if self._executor: self._executor.shutdown(wait=False) - _CrossSync_Sync.wait(self._channel_refresh_tasks, timeout=timeout) + CrossSync._Sync_Impl.wait(self._channel_refresh_tasks, timeout=timeout) def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None @@ -239,7 +239,7 @@ def _ping_and_warm_instances( ) for instance_name, table_name, app_profile_id in instance_list ] - result_list = _CrossSync_Sync.gather_partials( + result_list = 
CrossSync._Sync_Impl.gather_partials( partial_list, return_exceptions=True, sync_executor=self._executor ) return [r or None for r in result_list] @@ -270,7 +270,9 @@ def _manage_channel( grace_period: time to allow previous channel to serve existing requests before closing, in seconds """ - sleep_fn = asyncio.sleep if _CrossSync_Sync.is_async else self._is_closed.wait + sleep_fn = ( + asyncio.sleep if CrossSync._Sync_Impl.is_async else self._is_closed.wait + ) first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) @@ -506,7 +508,7 @@ def __init__( ) self.default_retryable_errors = default_retryable_errors or () try: - self._register_instance_future = _CrossSync_Sync.create_task( + self._register_instance_future = CrossSync._Sync_Impl.create_task( self.client._register_instance, self.instance_id, self, @@ -560,7 +562,7 @@ def read_rows_stream( operation_timeout, attempt_timeout, self ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - row_merger = _CrossSync_Sync[_ReadRowsOperationAsync]( + row_merger = CrossSync._Sync_Impl[_ReadRowsOperationAsync]( query, self, operation_timeout=operation_timeout, @@ -739,7 +741,7 @@ def read_rows_sharded( ) for query in batch ] - batch_result = _CrossSync_Sync.gather_partials( + batch_result = CrossSync._Sync_Impl.gather_partials( batch_partial_list, return_exceptions=True, sync_executor=self.client._executor, @@ -1057,7 +1059,7 @@ def bulk_mutate_rows( operation_timeout, attempt_timeout, self ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - operation = _CrossSync_Sync[_MutateRowsOperationAsync]( + operation = CrossSync._Sync_Impl[_MutateRowsOperationAsync]( self.client._gapic_client, self, mutation_entries, diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 05b2cd503..224845678 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -27,6 +27,12 @@ class _AsyncGetAttr(type): def __getitem__(cls, item): return item +class _SyncGetAttr(type): + + def __getitem__(cls, item): + breakpoint() + return CrossSync.generated_replacements[item] + class CrossSync(metaclass=_AsyncGetAttr): SyncImports = False @@ -121,76 +127,70 @@ async def yield_to_event_loop(): """ await asyncio.sleep(0) -class _SyncGetAttr(type): - - def __getitem__(cls, item): - breakpoint() - return CrossSync.generated_replacements[item] - - -class _CrossSync_Sync(metaclass=_SyncGetAttr): - - is_async = False - - sleep = time.sleep - Queue = queue.Queue - Condition = threading.Condition - Future = concurrent.futures.Future - Task = concurrent.futures.Future - Event = threading.Event - retry_target = retries.retry_target - retry_target_stream = retries.retry_target_stream - generated_replacements = {} - - @staticmethod - def wait(futures, timeout=None): - """ - abstraction over asyncio.wait - """ - if not futures: - return set(), set() - return concurrent.futures.wait(futures, timeout=timeout) - - @staticmethod - def condition_wait(condition, timeout=None): - """ - returns False if the timeout is reached before the condition is set, otherwise True - """ - return condition.wait(timeout=timeout) - @staticmethod - def gather_partials(partial_list, return_exceptions=False, sync_executor=None): - if not partial_list: - return [] - if not sync_executor: - raise ValueError("sync_executor is required for sync version") - futures_list = [ - sync_executor.submit(partial) for partial in 
partial_list - ] - results_list = [] - for future in futures_list: - if future.exception(): - if return_exceptions: - results_list.append(future.exception()) + class _Sync_Impl(metaclass=_SyncGetAttr): + + is_async = False + + sleep = time.sleep + Queue = queue.Queue + Condition = threading.Condition + Future = concurrent.futures.Future + Task = concurrent.futures.Future + Event = threading.Event + retry_target = retries.retry_target + retry_target_stream = retries.retry_target_stream + generated_replacements = {} + + @staticmethod + def wait(futures, timeout=None): + """ + abstraction over asyncio.wait + """ + if not futures: + return set(), set() + return concurrent.futures.wait(futures, timeout=timeout) + + @staticmethod + def condition_wait(condition, timeout=None): + """ + returns False if the timeout is reached before the condition is set, otherwise True + """ + return condition.wait(timeout=timeout) + + @staticmethod + def gather_partials(partial_list, return_exceptions=False, sync_executor=None): + if not partial_list: + return [] + if not sync_executor: + raise ValueError("sync_executor is required for sync version") + futures_list = [ + sync_executor.submit(partial) for partial in partial_list + ] + results_list = [] + for future in futures_list: + if future.exception(): + if return_exceptions: + results_list.append(future.exception()) + else: + raise future.exception() else: - raise future.exception() - else: - results_list.append(future.result()) - return results_list + results_list.append(future.result()) + return results_list - @staticmethod - def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): - """ - abstraction over asyncio.create_task. Sync version implemented with threadpool executor + @staticmethod + def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): + """ + abstraction over asyncio.create_task. Sync version implemented with threadpool executor - sync_executor: ThreadPoolExecutor to use for sync operations. Ignored in async version - """ - if not sync_executor: - raise ValueError("sync_executor is required for sync version") - return sync_executor.submit(fn, *fn_args, **fn_kwargs) + sync_executor: ThreadPoolExecutor to use for sync operations. 
Ignored in async version + """ + if not sync_executor: + raise ValueError("sync_executor is required for sync version") + return sync_executor.submit(fn, *fn_args, **fn_kwargs) - @staticmethod - def yield_to_event_loop(): - pass + @staticmethod + def yield_to_event_loop(): + pass diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 6f892bc94..4a798af3f 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -28,7 +28,7 @@ from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts -from google.cloud.bigtable.data._sync.cross_sync import _CrossSync_Sync +from google.cloud.bigtable.data._sync.cross_sync import CrossSync from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.mutations import Mutation @@ -56,7 +56,7 @@ class MutationsBatcher: def __init__( self, - table: _CrossSync_Sync[TableAsync], + table: CrossSync._Sync_Impl[TableAsync], *, flush_interval: float | None = 5, flush_limit_mutation_count: int | None = 1000, @@ -92,11 +92,11 @@ def __init__( self._retryable_errors: list[type[Exception]] = _get_retryable_errors( batch_retryable_errors, table ) - self._closed = _CrossSync_Sync.Event() + self._closed = CrossSync._Sync_Impl.Event() self._table = table self._staged_entries: list[RowMutationEntry] = [] self._staged_count, self._staged_bytes = (0, 0) - self._flow_control = _CrossSync_Sync[_FlowControlAsync]( + self._flow_control = CrossSync._Sync_Impl[_FlowControlAsync]( flow_control_max_mutation_count, flow_control_max_bytes ) self._flush_limit_bytes = flush_limit_bytes @@ -105,14 +105,14 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - if not _CrossSync_Sync.is_async: + if not CrossSync._Sync_Impl.is_async: self._sync_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) else: self._sync_executor = None - self._flush_timer = _CrossSync_Sync.create_task( + self._flush_timer = CrossSync._Sync_Impl.create_task( self._timer_routine, flush_interval, sync_executor=self._sync_executor ) - self._flush_jobs: set[_CrossSync_Sync.Future[None]] = set() + self._flush_jobs: set[CrossSync._Sync_Impl.Future[None]] = set() self._entries_processed_since_last_raise: int = 0 self._exceptions_since_last_raise: int = 0 self._exception_list_limit: int = 10 @@ -130,7 +130,7 @@ def _timer_routine(self, interval: float | None) -> None: if not interval or interval <= 0: return None while not self._closed.is_set(): - _CrossSync_Sync.condition_wait(self._closed, timeout=interval) + CrossSync._Sync_Impl.condition_wait(self._closed, timeout=interval) if not self._closed.is_set() and self._staged_entries: self._schedule_flush() @@ -160,14 +160,14 @@ def append(self, mutation_entry: RowMutationEntry): or self._staged_bytes >= self._flush_limit_bytes ): self._schedule_flush() - _CrossSync_Sync.yield_to_event_loop() + CrossSync._Sync_Impl.yield_to_event_loop() - def _schedule_flush(self) -> _CrossSync_Sync.Future[None] | None: + def _schedule_flush(self) -> CrossSync._Sync_Impl.Future[None] | None: """Update the flush task to include the latest staged entries""" if self._staged_entries: entries, self._staged_entries = (self._staged_entries, []) self._staged_count, 
self._staged_bytes = (0, 0) - new_task = _CrossSync_Sync.create_task( + new_task = CrossSync._Sync_Impl.create_task( self._flush_internal, entries, sync_executor=self._sync_executor ) if not new_task.done(): @@ -184,10 +184,10 @@ def _flush_internal(self, new_entries: list[RowMutationEntry]): - new_entries: list of RowMutationEntry objects to flush """ in_process_requests: list[ - _CrossSync_Sync.Future[list[FailedMutationEntryError]] + CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] ] = [] for batch in self._flow_control.add_to_flow(new_entries): - batch_task = _CrossSync_Sync.create_task( + batch_task = CrossSync._Sync_Impl.create_task( self._execute_mutate_rows, batch, sync_executor=self._sync_executor ) in_process_requests.append(batch_task) @@ -210,7 +210,7 @@ def _execute_mutate_rows( FailedMutationEntryError objects will not contain index information """ try: - operation = _CrossSync_Sync[_MutateRowsOperationAsync]( + operation = CrossSync._Sync_Impl[_MutateRowsOperationAsync]( self._table.client._gapic_client, self._table, batch, @@ -288,7 +288,7 @@ def close(self): self._closed.set() self._flush_timer.cancel() self._schedule_flush() - if _CrossSync_Sync.is_async: + if CrossSync._Sync_Impl.is_async: if self._flush_jobs: asyncio.gather(*self._flush_jobs, return_exceptions=True) try: @@ -310,8 +310,8 @@ def _on_exit(self): @staticmethod def _wait_for_batch_results( - *tasks: _CrossSync_Sync.Future[list[FailedMutationEntryError]] - | _CrossSync_Sync.Future[None], + *tasks: CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] + | CrossSync._Sync_Impl.Future[None], ) -> list[Exception]: """ Takes in a list of futures representing _execute_mutate_rows tasks, @@ -329,7 +329,7 @@ def _wait_for_batch_results( return [] exceptions: list[Exception] = [] for task in tasks: - if _CrossSync_Sync.is_async: + if CrossSync._Sync_Impl.is_async: task try: exc_list = task.result() @@ -367,7 +367,7 @@ def __init__(self, max_mutation_count: int, max_mutation_bytes: int): raise ValueError("max_mutation_count must be greater than 0") if self._max_mutation_bytes < 1: raise ValueError("max_mutation_bytes must be greater than 0") - self._capacity_condition = _CrossSync_Sync.Condition() + self._capacity_condition = CrossSync._Sync_Impl.Condition() self._in_flight_mutation_count = 0 self._in_flight_mutation_bytes = 0 diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 7074d628c..c34fb6209 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -316,8 +316,7 @@ def transform_class(in_obj: Type, **kwargs): transformer.visit(ast_tree) # find imports imports = transformer.get_imports(filename) - imports.add(ast.parse("from abc import ABC").body[0]) - imports.add(ast.parse("from google.cloud.bigtable.data._sync.cross_sync import _CrossSync_Sync").body[0]) + # imports.add(ast.parse("from abc import ABC").body[0]) # add locals from file, in case they are needed if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): with open(filename, "r") as f: @@ -361,7 +360,7 @@ def transform_class(in_obj: Type, **kwargs): combined_tree = ast.parse("") combined_imports = set() for async_class in class_map[output_file]: - text_replacements = {"CrossSync": "_CrossSync_Sync", **async_class.cross_sync_replace_symbols} + text_replacements = {"CrossSync": "CrossSync._Sync_Impl", **async_class.cross_sync_replace_symbols} tree_body, imports = transform_class(async_class, autogen_sync_name=async_class.cross_sync_class_name, text_replacements=text_replacements) # update 
combined data combined_tree.body.extend(tree_body) From 77165fd75b79a26e19f4790488f8afe23a52e081 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 16:51:22 -0700 Subject: [PATCH 072/360] fixed imports for mutate_rows --- .../bigtable/data/_async/_mutate_rows.py | 28 +++++++++++-------- .../cloud/bigtable/data/_sync/_mutate_rows.py | 20 +++++++------ 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index c576735b2..6cbc85483 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -21,10 +21,13 @@ from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries import google.cloud.bigtable_v2.types.bigtable as types_pb -import google.cloud.bigtable.data.exceptions as bt_exceptions from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup +from google.cloud.bigtable.data.exceptions import RetryExceptionGroup +from google.cloud.bigtable.data.exceptions import FailedMutationEntryError # mutate_rows requests are limited to this number of mutations from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT @@ -32,12 +35,15 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync if TYPE_CHECKING: - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) from google.cloud.bigtable.data.mutations import RowMutationEntry - from google.cloud.bigtable.data._async.client import TableAsync - + if CrossSync.is_async: + from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) + else: + from google.cloud.bigtable.data._sync.client import Table + from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient @dataclass class _EntryWithProto: @@ -107,7 +113,7 @@ def __init__( # RPC level errors *retryable_exceptions, # Entry level errors - bt_exceptions._MutateRowsIncomplete, + _MutateRowsIncomplete, ) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) # Note: _operation could be a raw coroutine, but using a lambda @@ -153,13 +159,13 @@ async def start(self): elif len(exc_list) == 1: cause_exc = exc_list[0] else: - cause_exc = bt_exceptions.RetryExceptionGroup(exc_list) + cause_exc = RetryExceptionGroup(exc_list) entry = self.mutations[idx].entry all_errors.append( - bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc) + FailedMutationEntryError(idx, entry, cause_exc) ) if all_errors: - raise bt_exceptions.MutationsExceptionGroup( + raise MutationsExceptionGroup( all_errors, len(self.mutations) ) @@ -215,7 +221,7 @@ async def _run_attempt(self): # check if attempt succeeded, or needs to be retried if self.remaining_indices: # unfinished work; raise exception to trigger retry - raise bt_exceptions._MutateRowsIncomplete + raise _MutateRowsIncomplete def _handle_entry_error(self, idx: int, exc: Exception): """ diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index fc0990e13..c368db229 100644 --- 
a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -25,9 +25,15 @@ from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data._sync.client import Table from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data.exceptions import FailedMutationEntryError +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup +from google.cloud.bigtable.data.exceptions import RetryExceptionGroup +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient class _MutateRowsOperation: @@ -73,7 +79,7 @@ def __init__( retry=None, ) self.is_retryable = retries.if_exception_type( - *retryable_exceptions, bt_exceptions._MutateRowsIncomplete + *retryable_exceptions, _MutateRowsIncomplete ) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) self._operation = lambda: CrossSync._Sync_Impl.retry_target( @@ -113,15 +119,11 @@ def start(self): elif len(exc_list) == 1: cause_exc = exc_list[0] else: - cause_exc = bt_exceptions.RetryExceptionGroup(exc_list) + cause_exc = RetryExceptionGroup(exc_list) entry = self.mutations[idx].entry - all_errors.append( - bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc) - ) + all_errors.append(FailedMutationEntryError(idx, entry, cause_exc)) if all_errors: - raise bt_exceptions.MutationsExceptionGroup( - all_errors, len(self.mutations) - ) + raise MutationsExceptionGroup(all_errors, len(self.mutations)) def _run_attempt(self): """ @@ -163,7 +165,7 @@ def _run_attempt(self): self._handle_entry_error(idx, exc) raise if self.remaining_indices: - raise bt_exceptions._MutateRowsIncomplete + raise _MutateRowsIncomplete def _handle_entry_error(self, idx: int, exc: Exception): """ From 9b048521c8d7595f0753c3761c587592089079e0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 16:55:31 -0700 Subject: [PATCH 073/360] fixed typing for _read_rows --- google/cloud/bigtable/data/_async/_read_rows.py | 8 ++++++-- google/cloud/bigtable/data/_sync/_read_rows.py | 10 ++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 987f0102b..3c431a0af 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -39,7 +39,11 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync if TYPE_CHECKING: - from google.cloud.bigtable.data._async.client import TableAsync + if CrossSync.is_async: + from google.cloud.bigtable.data._async.client import TableAsync + else: + from google.cloud.bigtable.data._sync.client import Table + from typing import Iterable class _ResetRow(Exception): @@ -49,7 +53,7 @@ def __init__(self, chunk): @CrossSync.sync_output("google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", replace_symbols={ - "AsyncIterable": "Iterable", "StopAsyncIteration": "StopIteration", "Awaitable": None, "TableAsync": "Table", + "AsyncIterable": "Iterable", "StopAsyncIteration": "StopIteration", "Awaitable": None, "TableAsync": "Table", 
"__aiter__": "__iter__", "__anext__": "__next__" } ) class _ReadRowsOperationAsync: diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index a5ee1d110..6225414ea 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -16,12 +16,14 @@ from __future__ import annotations +from typing import Iterable from typing import Sequence from google.api_core import retry as retries from google.api_core.retry import exponential_sleep_generator from google.cloud.bigtable.data import _helpers from google.cloud.bigtable.data._async._read_rows import _ResetRow +from google.cloud.bigtable.data._sync.client import Table from google.cloud.bigtable.data._sync.cross_sync import CrossSync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _RowSetComplete @@ -162,10 +164,10 @@ def merge_rows(chunks: Iterable[ReadRowsResponsePB.CellChunk] | None): """Merge chunks into rows""" if chunks is None: return - it = chunks.__aiter__() + it = chunks.__iter__() while True: try: - c = it.__anext__() + c = it.__next__() except StopIteration: return row_key = c.row_key @@ -203,7 +205,7 @@ def merge_rows(chunks: Iterable[ReadRowsResponsePB.CellChunk] | None): if c.value_size > 0: buffer = [value] while c.value_size > 0: - c = it.__anext__() + c = it.__next__() t = c.timestamp_micros cl = c.labels k = c.row_key @@ -233,7 +235,7 @@ def merge_rows(chunks: Iterable[ReadRowsResponsePB.CellChunk] | None): if c.commit_row: yield Row(row_key, cells) break - c = it.__anext__() + c = it.__next__() except _ResetRow as e: c = e.chunk if ( From d8542f4ee129b9b8aaf4944eed80e3762c31b277 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 20 Jun 2024 17:11:50 -0700 Subject: [PATCH 074/360] fixing typing issues --- google/cloud/bigtable/data/_async/client.py | 2 +- .../bigtable/data/_async/mutations_batcher.py | 40 +++++++++++++------ google/cloud/bigtable/data/_helpers.py | 7 +++- google/cloud/bigtable/data/_sync/client.py | 2 +- .../cloud/bigtable/data/_sync/cross_sync.py | 23 ++++++----- .../bigtable/data/_sync/mutations_batcher.py | 27 ++++++++----- 6 files changed, 65 insertions(+), 36 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 0d324952e..addc4dcdc 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -65,10 +65,10 @@ from google.cloud.bigtable.data import _helpers from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _MB_SIZE from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync -from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule from google.cloud.bigtable.data.row_filters import RowFilter from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index cda304756..50e399d9f 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -19,12 +19,14 @@ import atexit 
import warnings from collections import deque +import concurrent.futures from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _MB_SIZE from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync from google.cloud.bigtable.data.mutations import ( @@ -34,12 +36,15 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync + if TYPE_CHECKING: - from google.cloud.bigtable.data._async.client import TableAsync + if CrossSync.is_async: + from google.cloud.bigtable.data._async.client import TableAsync + else: + from google.cloud.bigtable.data._sync.client import Table + from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation from google.cloud.bigtable.data.mutations import RowMutationEntry -# used to make more readable default values -_MB_SIZE = 1024 * 1024 @CrossSync.sync_output("google.cloud.bigtable.data._sync.mutations_batcher._FlowControl") @@ -167,7 +172,14 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] yield mutations[start_idx:end_idx] -@CrossSync.sync_output("google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher") +@CrossSync.sync_output( + "google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", + replace_symbols={ + "TableAsync": "Table", + "_FlowControlAsync": "_FlowControl", + "_MutateRowsOperationAsync": "_MutateRowsOperation", + }, +) class MutationsBatcherAsync: """ Allows users to send batches using context manager API: @@ -188,7 +200,7 @@ class MutationsBatcherAsync: def __init__( self, - table: CrossSync[TableAsync], + table: TableAsync, *, flush_interval: float | None = 5, flush_limit_mutation_count: int | None = 1000, @@ -229,7 +241,7 @@ def __init__( self._table = table self._staged_entries: list[RowMutationEntry] = [] self._staged_count, self._staged_bytes = 0, 0 - self._flow_control = CrossSync[_FlowControlAsync]( + self._flow_control = _FlowControlAsync( flow_control_max_mutation_count, flow_control_max_bytes ) self._flush_limit_bytes = flush_limit_bytes @@ -238,11 +250,6 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - # in sync mode, use a threadpool executor for background tasks - if not CrossSync.is_async: - self._sync_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) - else: - self._sync_executor = None self._flush_timer = CrossSync.create_task(self._timer_routine, flush_interval, sync_executor=self._sync_executor) self._flush_jobs: set[CrossSync.Future[None]] = set() # MutationExceptionGroup reports number of successful entries along with failures @@ -256,6 +263,15 @@ def __init__( ) # clean up on program exit atexit.register(self._on_exit) + + @property + def _sync_executor(self) -> concurrent.futures.ThreadPoolExecutor: + if CrossSync.is_async: + raise AttributeError("sync_executor is not available in async mode") + if not hasattr(self, "_sync_executor_instance"): + self._sync_executor_instance = concurrent.futures.ThreadPoolExecutor(max_workers=8) + return self._sync_executor_instance + async def _timer_routine(self, interval: float | None) -> None: """ Triggers new flush tasks every `interval` seconds @@ -344,7 +360,7 @@ async def _execute_mutate_rows( 
FailedMutationEntryError objects will not contain index information """ try: - operation = CrossSync[_MutateRowsOperationAsync]( + operation = _MutateRowsOperationAsync( self._table.client._gapic_client, self._table, batch, diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index a0b13cbaf..4e3f5ab70 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -29,6 +29,7 @@ if TYPE_CHECKING: import grpc from google.cloud.bigtable.data import TableAsync + from google.cloud.bigtable.data._sync.client import Table """ Helper functions used in various places in the library. @@ -48,6 +49,8 @@ "_WarmedInstanceKey", ["instance_name", "table_name", "app_profile_id"] ) +# used to make more readable default values +_MB_SIZE = 1024 * 1024 # enum used on method calls when table defaults should be used class TABLE_DEFAULT(enum.Enum): @@ -133,7 +136,7 @@ def _retry_exception_factory( def _get_timeouts( operation: float | TABLE_DEFAULT, attempt: float | None | TABLE_DEFAULT, - table: "TableAsync", + table: "TableAsync" | "Table", ) -> tuple[float, float]: """ Convert passed in timeout values to floats, using table defaults if necessary. @@ -204,7 +207,7 @@ def _validate_timeouts( def _get_retryable_errors( call_codes: Sequence["grpc.StatusCode" | int | type[Exception]] | TABLE_DEFAULT, - table: "TableAsync", + table: "TableAsync" | "Table", ) -> list[type[Exception]]: # load table defaults if necessary if call_codes == TABLE_DEFAULT.DEFAULT: diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 220eb130d..2c64d7df5 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -39,10 +39,10 @@ from google.cloud.bigtable.data import _helpers from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync -from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _MB_SIZE from google.cloud.bigtable.data._sync.cross_sync import CrossSync from google.cloud.bigtable.data.exceptions import FailedQueryShardError from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 224845678..db0ba7d96 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from typing_extensions import TypeAlias import asyncio import sys @@ -39,13 +40,14 @@ class CrossSync(metaclass=_AsyncGetAttr): is_async = True sleep = asyncio.sleep - Queue = asyncio.Queue - Condition = asyncio.Condition - Future = asyncio.Future - Task = asyncio.Task - Event = asyncio.Event retry_target = retries.retry_target_async retry_target_stream = retries.retry_target_stream_async + Queue: TypeAlias = asyncio.Queue + Condition: TypeAlias = asyncio.Condition + Future: TypeAlias = asyncio.Future + Task: TypeAlias = asyncio.Task + Event: TypeAlias = asyncio.Event + generated_replacements = {} @staticmethod @@ -133,13 +135,14 @@ class _Sync_Impl(metaclass=_SyncGetAttr): is_async = False sleep = time.sleep - Queue = queue.Queue - Condition = threading.Condition - Future = concurrent.futures.Future - Task = concurrent.futures.Future - Event = threading.Event retry_target = retries.retry_target retry_target_stream = retries.retry_target_stream + Queue: TypeAlias = queue.Queue + Condition: TypeAlias = threading.Condition + Future:TypeAlias = concurrent.futures.Future + Task:TypeAlias = concurrent.futures.Future + Event:TypeAlias = threading.Event + generated_replacements = {} @staticmethod diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 4a798af3f..8156b3559 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -20,14 +20,15 @@ from typing import Sequence import asyncio import atexit +import concurrent.futures import warnings -from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync -from google.cloud.bigtable.data._async.client import TableAsync -from google.cloud.bigtable.data._async.mutations_batcher import _FlowControlAsync from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _MB_SIZE from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation +from google.cloud.bigtable.data._sync.client import Table from google.cloud.bigtable.data._sync.cross_sync import CrossSync from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup @@ -56,7 +57,7 @@ class MutationsBatcher: def __init__( self, - table: CrossSync._Sync_Impl[TableAsync], + table: Table, *, flush_interval: float | None = 5, flush_limit_mutation_count: int | None = 1000, @@ -96,7 +97,7 @@ def __init__( self._table = table self._staged_entries: list[RowMutationEntry] = [] self._staged_count, self._staged_bytes = (0, 0) - self._flow_control = CrossSync._Sync_Impl[_FlowControlAsync]( + self._flow_control = _FlowControl( flow_control_max_mutation_count, flow_control_max_bytes ) self._flush_limit_bytes = flush_limit_bytes @@ -105,10 +106,6 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - if not CrossSync._Sync_Impl.is_async: - self._sync_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8) - else: - self._sync_executor = None self._flush_timer = CrossSync._Sync_Impl.create_task( self._timer_routine, flush_interval, sync_executor=self._sync_executor ) @@ -122,6 +119,16 @@ def __init__( ) atexit.register(self._on_exit) + @property + def _sync_executor(self) -> concurrent.futures.ThreadPoolExecutor: + if 
CrossSync._Sync_Impl.is_async: + raise AttributeError("sync_executor is not available in async mode") + if not hasattr(self, "_sync_executor_instance"): + self._sync_executor_instance = concurrent.futures.ThreadPoolExecutor( + max_workers=8 + ) + return self._sync_executor_instance + def _timer_routine(self, interval: float | None) -> None: """ Triggers new flush tasks every `interval` seconds @@ -210,7 +217,7 @@ def _execute_mutate_rows( FailedMutationEntryError objects will not contain index information """ try: - operation = CrossSync._Sync_Impl[_MutateRowsOperationAsync]( + operation = _MutateRowsOperation( self._table.client._gapic_client, self._table, batch, From 975ed1234775ee904fb45d031fbc873d8d2e483b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 11:10:42 -0700 Subject: [PATCH 075/360] fixing mypy issues --- google/cloud/bigtable/data/_async/client.py | 42 ++++++++++++------- google/cloud/bigtable/data/_sync/client.py | 23 ++++++---- .../cloud/bigtable/data/_sync/cross_sync.py | 6 ++- 3 files changed, 48 insertions(+), 23 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index addc4dcdc..165266703 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -37,14 +37,7 @@ from grpc import Channel from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient -from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, -) -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel as AsyncPooledChannel, -) +from google.cloud.bigtable_v2.services.bigtable.transports.base import DEFAULT_CLIENT_INFO from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest from google.cloud.client import ClientWithProject from google.cloud.environment_vars import BIGTABLE_EMULATOR # type: ignore @@ -52,7 +45,6 @@ from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import ServiceUnavailable from google.api_core.exceptions import Aborted -from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync import google.auth.credentials import google.auth._default @@ -67,8 +59,7 @@ from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import _MB_SIZE from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry -from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync -from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync + from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule from google.cloud.bigtable.data.row_filters import RowFilter from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter @@ -77,6 +68,29 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if CrossSync.is_async: + from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync + from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio 
import ( + PooledBigtableGrpcAsyncIOTransport, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel as AsyncPooledChannel, + ) + from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient +else: + from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation + from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledBigtableGrpcTransport, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledChannel, + ) + from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient + from typing import Iterable if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples @@ -87,7 +101,7 @@ "google.cloud.bigtable.data._sync.client.BigtableDataClient", replace_symbols={ "__aenter__": "__enter__", "__aexit__": "__exit__", - "TableAsync": "Table", "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcIOTransport", + "TableAsync": "Table", "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", "BigtableAsyncClient": "BigtableClient", "AsyncPooledChannel": "PooledChannel" } ) @@ -946,7 +960,7 @@ async def execute_rpc(): ) return [(s.row_key, s.offset_bytes) async for s in results] - return await retries.retry_target_async( + return await CrossSync.retry_target( execute_rpc, predicate, sleep_generator, @@ -1079,7 +1093,7 @@ async def mutate_row( metadata=_helpers._make_metadata(self.table_name, self.app_profile_id), retry=None, ) - return await retries.retry_target_async( + return await CrossSync.retry_target( target, predicate, sleep_generator, diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 2c64d7df5..ae74124dc 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -19,6 +19,7 @@ from functools import partial from grpc import Channel from typing import Any +from typing import Iterable from typing import Optional from typing import Sequence from typing import Set @@ -44,6 +45,7 @@ from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import _MB_SIZE from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher from google.cloud.bigtable.data.exceptions import FailedQueryShardError from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.cloud.bigtable.data.mutations import Mutation @@ -55,8 +57,17 @@ from google.cloud.bigtable.data.row_filters import RowFilter from google.cloud.bigtable.data.row_filters import RowFilterChain from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter -from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO +from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta +from google.cloud.bigtable_v2.services.bigtable.transports.base import ( + DEFAULT_CLIENT_INFO, +) +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledBigtableGrpcTransport, +) +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledChannel, +) 
from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest from google.cloud.client import ClientWithProject from google.cloud.environment_vars import BIGTABLE_EMULATOR @@ -102,7 +113,7 @@ def __init__( - ValueError if pool_size is less than 1 """ transport_str = f"bt-{self._client_version()}-{pool_size}" - transport = PooledBigtableGrpcIOTransport.with_fixed_size(pool_size) + transport = PooledBigtableGrpcTransport.with_fixed_size(pool_size) BigtableClientMeta._transport_registry[transport_str] = transport client_info = DEFAULT_CLIENT_INFO client_info.client_library_version = self._client_version() @@ -130,9 +141,7 @@ def __init__( client_info=client_info, ) self._is_closed = asyncio.Event() - self.transport = cast( - PooledBigtableGrpcIOTransport, self._gapic_client.transport - ) + self.transport = cast(PooledBigtableGrpcTransport, self._gapic_client.transport) self._active_instances: Set[_helpers._WarmedInstanceKey] = set() self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() @@ -877,7 +886,7 @@ def execute_rpc(): ) return [(s.row_key, s.offset_bytes) for s in results] - return retries.retry_target_async( + return CrossSync._Sync_Impl.retry_target( execute_rpc, predicate, sleep_generator, @@ -1004,7 +1013,7 @@ def mutate_row( metadata=_helpers._make_metadata(self.table_name, self.app_profile_id), retry=None, ) - return retries.retry_target_async( + return CrossSync._Sync_Impl.retry_target( target, predicate, sleep_generator, diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index db0ba7d96..29a20e279 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from __future__ import annotations + from typing_extensions import TypeAlias import asyncio @@ -48,7 +50,7 @@ class CrossSync(metaclass=_AsyncGetAttr): Task: TypeAlias = asyncio.Task Event: TypeAlias = asyncio.Event - generated_replacements = {} + generated_replacements: dict[type, str] = {} @staticmethod def rename_sync(*args, **kwargs): @@ -143,7 +145,7 @@ class _Sync_Impl(metaclass=_SyncGetAttr): Task:TypeAlias = concurrent.futures.Future Event:TypeAlias = threading.Event - generated_replacements = {} + generated_replacements: dict[type, str] = {} @staticmethod def wait(futures, timeout=None): From 2dc3de48eef00121a97688c0f288abc64f7e106d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 11:13:36 -0700 Subject: [PATCH 076/360] ran blacken --- .../bigtable/data/_async/_mutate_rows.py | 12 ++--- .../cloud/bigtable/data/_async/_read_rows.py | 12 +++-- google/cloud/bigtable/data/_async/client.py | 51 +++++++++++++------ .../bigtable/data/_async/mutations_batcher.py | 27 ++++++---- google/cloud/bigtable/data/_helpers.py | 1 + .../cloud/bigtable/data/_sync/cross_sync.py | 32 ++++++------ 6 files changed, 84 insertions(+), 51 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 6cbc85483..5978fcc98 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -36,6 +36,7 @@ if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry + if CrossSync.is_async: from google.cloud.bigtable.data._async.client import TableAsync from google.cloud.bigtable_v2.services.bigtable.async_client import ( @@ -45,6 +46,7 @@ from google.cloud.bigtable.data._sync.client import Table from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient + @dataclass class _EntryWithProto: """ @@ -60,7 +62,7 @@ class _EntryWithProto: replace_symbols={ "BigtableAsyncClient": "BigtableClient", "TableAsync": "Table", - } + }, ) class _MutateRowsOperationAsync: """ @@ -161,13 +163,9 @@ async def start(self): else: cause_exc = RetryExceptionGroup(exc_list) entry = self.mutations[idx].entry - all_errors.append( - FailedMutationEntryError(idx, entry, cause_exc) - ) + all_errors.append(FailedMutationEntryError(idx, entry, cause_exc)) if all_errors: - raise MutationsExceptionGroup( - all_errors, len(self.mutations) - ) + raise MutationsExceptionGroup(all_errors, len(self.mutations)) async def _run_attempt(self): """ diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 3c431a0af..fb1f67add 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -51,10 +51,16 @@ def __init__(self, chunk): self.chunk = chunk -@CrossSync.sync_output("google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", +@CrossSync.sync_output( + "google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", replace_symbols={ - "AsyncIterable": "Iterable", "StopAsyncIteration": "StopIteration", "Awaitable": None, "TableAsync": "Table", "__aiter__": "__iter__", "__anext__": "__next__" - } + "AsyncIterable": "Iterable", + "StopAsyncIteration": "StopIteration", + "Awaitable": None, + "TableAsync": "Table", + "__aiter__": "__iter__", + "__anext__": "__next__", + }, ) class _ReadRowsOperationAsync: """ diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 165266703..c6cef8fee 100644 --- 
a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -37,7 +37,9 @@ from grpc import Channel from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.transports.base import DEFAULT_CLIENT_INFO +from google.cloud.bigtable_v2.services.bigtable.transports.base import ( + DEFAULT_CLIENT_INFO, +) from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest from google.cloud.client import ClientWithProject from google.cloud.environment_vars import BIGTABLE_EMULATOR # type: ignore @@ -70,7 +72,9 @@ if CrossSync.is_async: from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync - from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( PooledBigtableGrpcAsyncIOTransport, @@ -78,7 +82,9 @@ from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( PooledChannel as AsyncPooledChannel, ) - from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) else: from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher @@ -100,13 +106,15 @@ @CrossSync.sync_output( "google.cloud.bigtable.data._sync.client.BigtableDataClient", replace_symbols={ - "__aenter__": "__enter__", "__aexit__": "__exit__", - "TableAsync": "Table", "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", - "BigtableAsyncClient": "BigtableClient", "AsyncPooledChannel": "PooledChannel" - } + "__aenter__": "__enter__", + "__aexit__": "__exit__", + "TableAsync": "Table", + "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", + "BigtableAsyncClient": "BigtableClient", + "AsyncPooledChannel": "PooledChannel", + }, ) class BigtableDataClientAsync(ClientWithProject): - def __init__( self, *, @@ -188,7 +196,9 @@ def __init__( self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() self._channel_refresh_tasks: list[asyncio.Task[None]] = [] - self._executor = concurrent.futures.ThreadPoolExecutor() if not CrossSync.is_async else None + self._executor = ( + concurrent.futures.ThreadPoolExecutor() if not CrossSync.is_async else None + ) if self._emulator_host is not None: # connect to an emulator host warnings.warn( @@ -239,12 +249,15 @@ def _start_background_channel_refresh(self) -> None: ): for channel_idx in range(self.transport.pool_size): refresh_task = CrossSync.create_task( - self._manage_channel, channel_idx, + self._manage_channel, + channel_idx, sync_executor=self._executor, - task_name=f"{self.__class__.__name__} channel refresh {channel_idx}" + task_name=f"{self.__class__.__name__} channel refresh {channel_idx}", ) self._channel_refresh_tasks.append(refresh_task) - refresh_task.add_done_callback(lambda _: self._channel_refresh_tasks.remove(refresh_task)) + refresh_task.add_done_callback( + lambda _: self._channel_refresh_tasks.remove(refresh_task) + ) async def close(self, timeout: float | None = None): """ @@ -294,7 +307,9 @@ async def 
_ping_and_warm_instances( ) for (instance_name, table_name, app_profile_id) in instance_list ] - result_list = await CrossSync.gather_partials(partial_list, return_exceptions=True, sync_executor=self._executor) + result_list = await CrossSync.gather_partials( + partial_list, return_exceptions=True, sync_executor=self._executor + ) return [r or None for r in result_list] async def _manage_channel( @@ -464,7 +479,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): "AsyncIterable": "Iterable", "MutationsBatcherAsync": "MutationsBatcher", "BigtableDataClientAsync": "BigtableDataClient", - } + }, ) class TableAsync: """ @@ -584,7 +599,9 @@ def __init__( self.default_retryable_errors = default_retryable_errors or () try: self._register_instance_future = CrossSync.create_task( - self.client._register_instance, self.instance_id, self, + self.client._register_instance, + self.instance_id, + self, sync_executor=self.client._executor, ) except RuntimeError as e: @@ -819,7 +836,9 @@ async def read_rows_sharded( for query in batch ] batch_result = await CrossSync.gather_partials( - batch_partial_list, return_exceptions=True, sync_executor=self.client._executor + batch_partial_list, + return_exceptions=True, + sync_executor=self.client._executor, ) for result in batch_result: if isinstance(result, Exception): diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 50e399d9f..f5a6c86cc 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -46,8 +46,9 @@ from google.cloud.bigtable.data.mutations import RowMutationEntry - -@CrossSync.sync_output("google.cloud.bigtable.data._sync.mutations_batcher._FlowControl") +@CrossSync.sync_output( + "google.cloud.bigtable.data._sync.mutations_batcher._FlowControl" +) class _FlowControlAsync: """ Manages flow control for batched mutations. 
Mutations are registered against @@ -149,7 +150,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] while end_idx < len(mutations): next_entry = mutations[end_idx] next_size = next_entry.size() - next_count= len(next_entry.mutations) + next_count = len(next_entry.mutations) if ( self._has_capacity(next_count, next_size) # make sure not to exceed per-request mutation count limits @@ -250,7 +251,9 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - self._flush_timer = CrossSync.create_task(self._timer_routine, flush_interval, sync_executor=self._sync_executor) + self._flush_timer = CrossSync.create_task( + self._timer_routine, flush_interval, sync_executor=self._sync_executor + ) self._flush_jobs: set[CrossSync.Future[None]] = set() # MutationExceptionGroup reports number of successful entries along with failures self._entries_processed_since_last_raise: int = 0 @@ -269,7 +272,9 @@ def _sync_executor(self) -> concurrent.futures.ThreadPoolExecutor: if CrossSync.is_async: raise AttributeError("sync_executor is not available in async mode") if not hasattr(self, "_sync_executor_instance"): - self._sync_executor_instance = concurrent.futures.ThreadPoolExecutor(max_workers=8) + self._sync_executor_instance = concurrent.futures.ThreadPoolExecutor( + max_workers=8 + ) return self._sync_executor_instance async def _timer_routine(self, interval: float | None) -> None: @@ -320,7 +325,9 @@ def _schedule_flush(self) -> CrossSync.Future[None] | None: if self._staged_entries: entries, self._staged_entries = self._staged_entries, [] self._staged_count, self._staged_bytes = 0, 0 - new_task = CrossSync.create_task(self._flush_internal, entries, sync_executor=self._sync_executor) + new_task = CrossSync.create_task( + self._flush_internal, entries, sync_executor=self._sync_executor + ) if not new_task.done(): self._flush_jobs.add(new_task) new_task.add_done_callback(self._flush_jobs.remove) @@ -337,7 +344,9 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): # flush new entries in_process_requests: list[CrossSync.Future[list[FailedMutationEntryError]]] = [] async for batch in self._flow_control.add_to_flow(new_entries): - batch_task = CrossSync.create_task(self._execute_mutate_rows, batch, sync_executor=self._sync_executor) + batch_task = CrossSync.create_task( + self._execute_mutate_rows, batch, sync_executor=self._sync_executor + ) in_process_requests.append(batch_task) # wait for all inflight requests to complete found_exceptions = await self._wait_for_batch_results(*in_process_requests) @@ -472,10 +481,10 @@ def _on_exit(self): f"{len(self._staged_entries)} Unflushed mutations will not be sent to the server." 
) - @staticmethod async def _wait_for_batch_results( - *tasks: CrossSync.Future[list[FailedMutationEntryError]] | CrossSync.Future[None], + *tasks: CrossSync.Future[list[FailedMutationEntryError]] + | CrossSync.Future[None], ) -> list[Exception]: """ Takes in a list of futures representing _execute_mutate_rows tasks, diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index 4e3f5ab70..fb51df36a 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -52,6 +52,7 @@ # used to make more readable default values _MB_SIZE = 1024 * 1024 + # enum used on method calls when table defaults should be used class TABLE_DEFAULT(enum.Enum): # default for mutate_row, sample_row_keys, check_and_mutate_row, and read_modify_write_row diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 29a20e279..1934d4490 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -26,18 +26,17 @@ class _AsyncGetAttr(type): - def __getitem__(cls, item): return item -class _SyncGetAttr(type): +class _SyncGetAttr(type): def __getitem__(cls, item): breakpoint() return CrossSync.generated_replacements[item] -class CrossSync(metaclass=_AsyncGetAttr): +class CrossSync(metaclass=_AsyncGetAttr): SyncImports = False is_async = True @@ -56,24 +55,29 @@ class CrossSync(metaclass=_AsyncGetAttr): def rename_sync(*args, **kwargs): def decorator(func): return func + return decorator @classmethod def sync_output(cls, sync_path, replace_symbols=None): replace_symbols = replace_symbols or {} + # return the async class unchanged def decorator(async_cls): cls.generated_replacements[async_cls] = sync_path async_cls.cross_sync_enabled = True async_cls.cross_sync_import_path = sync_path - async_cls.cross_sync_class_name = sync_path.rsplit('.', 1)[-1] + async_cls.cross_sync_class_name = sync_path.rsplit(".", 1)[-1] async_cls.cross_sync_file_path = "/".join(sync_path.split(".")[:-1]) + ".py" async_cls.cross_sync_replace_symbols = replace_symbols return async_cls + return decorator @staticmethod - async def gather_partials(partial_list, return_exceptions=False, sync_executor=None): + async def gather_partials( + partial_list, return_exceptions=False, sync_executor=None + ): """ abstraction over asyncio.gather @@ -88,7 +92,9 @@ async def gather_partials(partial_list, return_exceptions=False, sync_executor=N if not partial_list: return [] awaitable_list = [partial() for partial in partial_list] - return await asyncio.gather(*awaitable_list, return_exceptions=return_exceptions) + return await asyncio.gather( + *awaitable_list, return_exceptions=return_exceptions + ) @staticmethod async def wait(futures, timeout=None): @@ -131,9 +137,7 @@ async def yield_to_event_loop(): """ await asyncio.sleep(0) - class _Sync_Impl(metaclass=_SyncGetAttr): - is_async = False sleep = time.sleep @@ -141,9 +145,9 @@ class _Sync_Impl(metaclass=_SyncGetAttr): retry_target_stream = retries.retry_target_stream Queue: TypeAlias = queue.Queue Condition: TypeAlias = threading.Condition - Future:TypeAlias = concurrent.futures.Future - Task:TypeAlias = concurrent.futures.Future - Event:TypeAlias = threading.Event + Future: TypeAlias = concurrent.futures.Future + Task: TypeAlias = concurrent.futures.Future + Event: TypeAlias = threading.Event generated_replacements: dict[type, str] = {} @@ -169,9 +173,7 @@ def gather_partials(partial_list, return_exceptions=False, 
sync_executor=None): return [] if not sync_executor: raise ValueError("sync_executor is required for sync version") - futures_list = [ - sync_executor.submit(partial) for partial in partial_list - ] + futures_list = [sync_executor.submit(partial) for partial in partial_list] results_list = [] for future in futures_list: if future.exception(): @@ -183,8 +185,6 @@ def gather_partials(partial_list, return_exceptions=False, sync_executor=None): results_list.append(future.result()) return results_list - - @staticmethod def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): """ From 9bb7f6ea61ad8f1cb05e5e054b320cc2f4922cfb Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 11:20:51 -0700 Subject: [PATCH 077/360] fixed executor in mutations batcher --- .../bigtable/data/_async/mutations_batcher.py | 19 +++++++------------ .../bigtable/data/_sync/mutations_batcher.py | 17 ++++++----------- 2 files changed, 13 insertions(+), 23 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index f5a6c86cc..8eb502932 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -251,6 +251,11 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) + self._sync_executor = ( + concurrent.futures.ThreadPoolExecutor(max_workers=8) + if not CrossSync.is_async + else None + ) self._flush_timer = CrossSync.create_task( self._timer_routine, flush_interval, sync_executor=self._sync_executor ) @@ -267,16 +272,6 @@ def __init__( # clean up on program exit atexit.register(self._on_exit) - @property - def _sync_executor(self) -> concurrent.futures.ThreadPoolExecutor: - if CrossSync.is_async: - raise AttributeError("sync_executor is not available in async mode") - if not hasattr(self, "_sync_executor_instance"): - self._sync_executor_instance = concurrent.futures.ThreadPoolExecutor( - max_workers=8 - ) - return self._sync_executor_instance - async def _timer_routine(self, interval: float | None) -> None: """ Triggers new flush tasks every `interval` seconds @@ -463,8 +458,8 @@ async def close(self): await self._flush_timer except asyncio.CancelledError: pass - else: - # shut down executor + # shut down executor + if self._sync_executor: with self._sync_executor: self._sync_executor.shutdown(wait=True) atexit.unregister(self._on_exit) diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 8156b3559..fcdf263e8 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -106,6 +106,11 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) + self._sync_executor = ( + concurrent.futures.ThreadPoolExecutor(max_workers=8) + if not CrossSync._Sync_Impl.is_async + else None + ) self._flush_timer = CrossSync._Sync_Impl.create_task( self._timer_routine, flush_interval, sync_executor=self._sync_executor ) @@ -119,16 +124,6 @@ def __init__( ) atexit.register(self._on_exit) - @property - def _sync_executor(self) -> concurrent.futures.ThreadPoolExecutor: - if CrossSync._Sync_Impl.is_async: - raise AttributeError("sync_executor is not available in async mode") - if not hasattr(self, "_sync_executor_instance"): - self._sync_executor_instance = concurrent.futures.ThreadPoolExecutor( - max_workers=8 - ) - return self._sync_executor_instance - def 
_timer_routine(self, interval: float | None) -> None: """ Triggers new flush tasks every `interval` seconds @@ -302,7 +297,7 @@ def close(self): self._flush_timer except asyncio.CancelledError: pass - else: + if self._sync_executor: with self._sync_executor: self._sync_executor.shutdown(wait=True) atexit.unregister(self._on_exit) From cc03c157753f42b3a8b85edfabe7f7bc878e42b7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 11:44:37 -0700 Subject: [PATCH 078/360] improved mutations batcher close --- google/cloud/bigtable/data/_async/mutations_batcher.py | 10 ++-------- google/cloud/bigtable/data/_sync/mutations_batcher.py | 10 ++-------- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 8eb502932..9010604ce 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -450,14 +450,8 @@ async def close(self): self._closed.set() self._flush_timer.cancel() self._schedule_flush() - if CrossSync.is_async: - # flush remaining tasks - if self._flush_jobs: - await asyncio.gather(*self._flush_jobs, return_exceptions=True) - try: - await self._flush_timer - except asyncio.CancelledError: - pass + await CrossSync.wait(self._flush_jobs) + await CrossSync.wait(self._flush_timer) # shut down executor if self._sync_executor: with self._sync_executor: diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index fcdf263e8..531499609 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -18,7 +18,6 @@ from __future__ import annotations from collections import deque from typing import Sequence -import asyncio import atexit import concurrent.futures import warnings @@ -290,13 +289,8 @@ def close(self): self._closed.set() self._flush_timer.cancel() self._schedule_flush() - if CrossSync._Sync_Impl.is_async: - if self._flush_jobs: - asyncio.gather(*self._flush_jobs, return_exceptions=True) - try: - self._flush_timer - except asyncio.CancelledError: - pass + CrossSync._Sync_Impl.wait(self._flush_jobs) + CrossSync._Sync_Impl.wait(self._flush_timer) if self._sync_executor: with self._sync_executor: self._sync_executor.shutdown(wait=True) From c71e5a924cb9f58cd30e37bf99de7d68d96f5135 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 11:56:42 -0700 Subject: [PATCH 079/360] updated docstring --- google/cloud/bigtable/data/_sync/cross_sync.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 1934d4490..488d68d0b 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -79,7 +79,9 @@ async def gather_partials( partial_list, return_exceptions=False, sync_executor=None ): """ - abstraction over asyncio.gather + abstraction over asyncio.gather, but with a set of partial functions instead + of coroutines, to work with sync functions. + To use gather with a set of futures instead of partials, use CrpssSync.wait In the async version, the partials are expected to return an awaitable object. Patials are unpacked and awaited in the gather call. 
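To make the partial-based contract in the docstring above concrete, here is a minimal sketch of how one call site can stay identical across both surfaces. It assumes the cross_sync module from this branch is importable; the helper functions double_async and double_sync are made up purely for illustration, and only the gather_partials signatures are taken from the diffs above.

    import asyncio
    import concurrent.futures
    from functools import partial

    from google.cloud.bigtable.data._sync.cross_sync import CrossSync


    async def double_async(i):
        # stand-in for an rpc; each partial() call produces a fresh coroutine
        await asyncio.sleep(0)
        return i * 2


    def double_sync(i):
        # stand-in used by the generated sync surface; a plain callable
        return i * 2


    async def async_caller():
        # async surface: gather_partials unpacks and awaits each partial()
        partials = [partial(double_async, i) for i in range(3)]
        return await CrossSync.gather_partials(partials, return_exceptions=False)


    def sync_caller(executor: concurrent.futures.ThreadPoolExecutor):
        # sync surface: the same call shape, but each partial is submitted to
        # the executor and results are collected from the returned futures
        partials = [partial(double_sync, i) for i in range(3)]
        return CrossSync._Sync_Impl.gather_partials(
            partials, return_exceptions=False, sync_executor=executor
        )


    if __name__ == "__main__":
        assert asyncio.run(async_caller()) == [0, 2, 4]
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
            assert sync_caller(executor) == [0, 2, 4]

Because the call site only builds zero-argument partials, the sync generator can rewrite CrossSync to CrossSync._Sync_Impl without touching the surrounding logic; CrossSync.wait remains the right tool when the caller already holds futures or tasks rather than callables.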
From 0500969702a60aec55ab05cf0a78316f488c2e44 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 12:17:57 -0700 Subject: [PATCH 080/360] fixed event wait --- google/cloud/bigtable/data/_async/client.py | 3 +-- google/cloud/bigtable/data/_sync/client.py | 7 +++-- .../cloud/bigtable/data/_sync/cross_sync.py | 26 +++++++++++++++++++ 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index c6cef8fee..cad16218e 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -338,7 +338,6 @@ async def _manage_channel( grace_period: time to allow previous channel to serve existing requests before closing, in seconds """ - sleep_fn = asyncio.sleep if CrossSync.is_async else self._is_closed.wait first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) @@ -349,7 +348,7 @@ async def _manage_channel( await self._ping_and_warm_instances(channel) # continuously refresh the channel every `refresh_interval` seconds while not self._is_closed.is_set(): - await sleep_fn(next_sleep) + await CrossSync.event_wait(self._is_closed, next_sleep, async_break_early=False) if self._is_closed.is_set(): break # prepare new channel for use diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index ae74124dc..bc8ec800f 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -279,9 +279,6 @@ def _manage_channel( grace_period: time to allow previous channel to serve existing requests before closing, in seconds """ - sleep_fn = ( - asyncio.sleep if CrossSync._Sync_Impl.is_async else self._is_closed.wait - ) first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) @@ -290,7 +287,9 @@ def _manage_channel( channel = self.transport.channels[channel_idx] self._ping_and_warm_instances(channel) while not self._is_closed.is_set(): - sleep_fn(next_sleep) + CrossSync._Sync_Impl.event_wait( + self._is_closed, next_sleep, async_break_early=False + ) if self._is_closed.is_set(): break new_channel = self.transport.grpc_channel._create_channel() diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 488d68d0b..c0b2c5e3a 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -120,6 +120,28 @@ async def condition_wait(condition, timeout=None): except asyncio.TimeoutError: return False + @staticmethod + async def event_wait(event, timeout=None, async_break_early=True): + """ + abstraction over asyncio.Event.wait + + Args: + - event: event to wait for + - timeout: if set, will break out early after `timeout` seconds + - async_break_early: if False, the async version will wait for + the full timeout even if the event is set before the timeout. 
+ This avoids creating a new background task + """ + if timeout is None: + await event.wait() + elif not async_break_early: + await asyncio.sleep(timeout) + else: + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + pass + @staticmethod def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): """ @@ -169,6 +191,10 @@ def condition_wait(condition, timeout=None): """ return condition.wait(timeout=timeout) + @staticmethod + def event_wait(event, timeout=None, async_break_early=True): + event.wait(timeout=timeout) + @staticmethod def gather_partials(partial_list, return_exceptions=False, sync_executor=None): if not partial_list: From 5cd4528da356dfae4fc65be06233412776672325 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 13:04:10 -0700 Subject: [PATCH 081/360] added way to ignore mypy errors on generated files --- google/cloud/bigtable/data/_async/mutations_batcher.py | 1 + google/cloud/bigtable/data/_sync/cross_sync.py | 4 +++- google/cloud/bigtable/data/_sync/mutations_batcher.py | 2 +- sync_surface_generator.py | 7 ++++++- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 9010604ce..a203a0027 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -180,6 +180,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] "_FlowControlAsync": "_FlowControl", "_MutateRowsOperationAsync": "_MutateRowsOperation", }, + mypy_ignore=["unreachable"], ) class MutationsBatcherAsync: """ diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index c0b2c5e3a..33cfbf12a 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -59,8 +59,9 @@ def decorator(func): return decorator @classmethod - def sync_output(cls, sync_path, replace_symbols=None): + def sync_output(cls, sync_path, replace_symbols=None, mypy_ignore=None): replace_symbols = replace_symbols or {} + mypy_ignore = mypy_ignore or [] # return the async class unchanged def decorator(async_cls): @@ -70,6 +71,7 @@ def decorator(async_cls): async_cls.cross_sync_class_name = sync_path.rsplit(".", 1)[-1] async_cls.cross_sync_file_path = "/".join(sync_path.split(".")[:-1]) + ".py" async_cls.cross_sync_replace_symbols = replace_symbols + async_cls.cross_sync_mypy_ignore = mypy_ignore return async_cls return decorator diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 531499609..d4e6c7f0c 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -13,7 +13,7 @@ # limitations under the License. # # This file is automatically generated by sync_surface_generator.py. Do not edit. 
- +# mypy: disable-error-code="unreachable" from __future__ import annotations from collections import deque diff --git a/sync_surface_generator.py b/sync_surface_generator.py index c34fb6209..efd964ab4 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -357,10 +357,12 @@ def transform_class(in_obj: Type, **kwargs): # generate sync code for each class for output_file in class_map.keys(): # initialize new tree and import list + file_mypy_ignore = set() combined_tree = ast.parse("") combined_imports = set() for async_class in class_map[output_file]: text_replacements = {"CrossSync": "CrossSync._Sync_Impl", **async_class.cross_sync_replace_symbols} + file_mypy_ignore.update(async_class.cross_sync_mypy_ignore) tree_body, imports = transform_class(async_class, autogen_sync_name=async_class.cross_sync_class_name, text_replacements=text_replacements) # update combined data combined_tree.body.extend(tree_body) @@ -375,6 +377,9 @@ def transform_class(in_obj: Type, **kwargs): else: non_google.append(i) import_str = "\n".join(non_google + [""] + google) + mypy_ignore_str = ", ".join(file_mypy_ignore) + if mypy_ignore_str: + mypy_ignore_str = f"# mypy: disable-error-code=\"{mypy_ignore_str}\"" # append clean tree header = """# Copyright 2024 Google LLC # @@ -392,7 +397,7 @@ def transform_class(in_obj: Type, **kwargs): # # This file is automatically generated by sync_surface_generator.py. Do not edit. """ - full_code = f"{header}\n\n{import_str}\n\n{ast.unparse(combined_tree)}" + full_code = f"{header}{mypy_ignore_str}\n\n{import_str}\n\n{ast.unparse(combined_tree)}" full_code = autoflake.fix_code(full_code, remove_all_unused_imports=True) formatted_code = format_str(full_code, mode=FileMode()) print(f"saving {[c.cross_sync_class_name for c in class_map[output_file]]} to {output_file}...") From c6840b8851ddb5642403abba6331650f71fc2f15 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 13:13:12 -0700 Subject: [PATCH 082/360] fixed import lint --- .../bigtable/data/_async/_mutate_rows.py | 6 +++-- .../cloud/bigtable/data/_async/_read_rows.py | 4 +-- google/cloud/bigtable/data/_async/client.py | 27 ++++++++++++------- .../bigtable/data/_async/mutations_batcher.py | 12 +++++---- .../cloud/bigtable/data/_sync/cross_sync.py | 2 +- 5 files changed, 32 insertions(+), 19 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 5978fcc98..9b04b362e 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -43,8 +43,10 @@ BigtableAsyncClient, ) else: - from google.cloud.bigtable.data._sync.client import Table - from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient + from google.cloud.bigtable.data._sync.client import Table # noqa: F401 + from google.cloud.bigtable_v2.services.bigtable.client import ( # noqa: F401 + BigtableClient, + ) @dataclass diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index fb1f67add..83673458b 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -42,8 +42,8 @@ if CrossSync.is_async: from google.cloud.bigtable.data._async.client import TableAsync else: - from google.cloud.bigtable.data._sync.client import Table - from typing import Iterable + from google.cloud.bigtable.data._sync.client import Table # noqa: F401 + from typing import Iterable # noqa: F401 class 
_ResetRow(Exception): diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index cad16218e..ed9fb8035 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -28,7 +28,6 @@ import asyncio import time import warnings -import sys import random import os import concurrent.futures @@ -86,17 +85,25 @@ BigtableAsyncClient, ) else: - from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation - from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + from google.cloud.bigtable.data._sync._mutate_rows import ( # noqa: F401 + _MutateRowsOperation, + ) + from google.cloud.bigtable.data._sync.mutations_batcher import ( # noqa: F401 + MutationsBatcher, + ) + from google.cloud.bigtable.data._sync._read_rows import ( # noqa: F401 + _ReadRowsOperation, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( # noqa: F401 PooledBigtableGrpcTransport, ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( # noqa: F401 PooledChannel, ) - from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient - from typing import Iterable + from google.cloud.bigtable_v2.services.bigtable.client import ( # noqa: F401 + BigtableClient, + ) + from typing import Iterable # noqa: F401 if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples @@ -348,7 +355,9 @@ async def _manage_channel( await self._ping_and_warm_instances(channel) # continuously refresh the channel every `refresh_interval` seconds while not self._is_closed.is_set(): - await CrossSync.event_wait(self._is_closed, next_sleep, async_break_early=False) + await CrossSync.event_wait( + self._is_closed, next_sleep, async_break_early=False + ) if self._is_closed.is_set(): break # prepare new channel for use diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index a203a0027..b506da032 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -14,8 +14,7 @@ # from __future__ import annotations -from typing import Any, Sequence, TYPE_CHECKING -import asyncio +from typing import Sequence, TYPE_CHECKING import atexit import warnings from collections import deque @@ -38,12 +37,15 @@ if TYPE_CHECKING: + from google.cloud.bigtable.data.mutations import RowMutationEntry + if CrossSync.is_async: from google.cloud.bigtable.data._async.client import TableAsync else: - from google.cloud.bigtable.data._sync.client import Table - from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation - from google.cloud.bigtable.data.mutations import RowMutationEntry + from google.cloud.bigtable.data._sync.client import Table # noqa: F401 + from google.cloud.bigtable.data._sync._mutate_rows import ( # noqa: F401 + _MutateRowsOperation, + ) @CrossSync.sync_output( diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 33cfbf12a..e671f9336 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -82,7 +82,7 @@ async def gather_partials( 
): """ abstraction over asyncio.gather, but with a set of partial functions instead - of coroutines, to work with sync functions. + of coroutines, to work with sync functions. To use gather with a set of futures instead of partials, use CrpssSync.wait In the async version, the partials are expected to return an awaitable object. Patials From b52b70be9cd968fa8f1a5083109ffde40ce29586 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 13:20:22 -0700 Subject: [PATCH 083/360] fixed mutations batcher async tests --- google/cloud/bigtable/data/_async/mutations_batcher.py | 3 +-- tests/unit/data/_async/test_mutations_batcher.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index b506da032..f4ce4856b 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -453,8 +453,7 @@ async def close(self): self._closed.set() self._flush_timer.cancel() self._schedule_flush() - await CrossSync.wait(self._flush_jobs) - await CrossSync.wait(self._flush_timer) + await CrossSync.wait([*self._flush_jobs, self._flush_timer]) # shut down executor if self._sync_executor: with self._sync_executor: diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index db6dbd725..170596f6e 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -267,7 +267,7 @@ async def test_add_to_flow_max_mutation_limits( max_limit, ) sync_patch = mock.patch( - "google.cloud.bigtable.data._sync._autogen._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + "google.cloud.bigtable.data.mutations._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", max_limit, ) with async_patch, sync_patch: From a476627c4f0ee11909080966dbc25c6a13848b36 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 14:51:27 -0700 Subject: [PATCH 084/360] fixed mock path --- tests/unit/data/_async/test_client.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 5009639aa..d818266f7 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1333,8 +1333,10 @@ async def test_customizable_retryable_errors( if is_stream: retry_fn += "_stream" if self.is_async: - retry_fn += "_async" - with mock.patch(f"google.api_core.retry.{retry_fn}") as retry_fn_mock: + retry_fn = f"CrossSync.{retry_fn}" + else: + retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" + with mock.patch(f"google.cloud.bigtable.data._sync.cross_sync.{retry_fn}") as retry_fn_mock: async with self._make_client() as client: table = client.get_table("instance-id", "table-id") expected_predicate = lambda a: a in expected_retryables # noqa From 948d4a2e0d1f3cfe568c16c92384ff4698fa1ff8 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 15:00:40 -0700 Subject: [PATCH 085/360] added event loop check --- google/cloud/bigtable/data/_async/client.py | 3 +++ google/cloud/bigtable/data/_sync/cross_sync.py | 1 + 2 files changed, 4 insertions(+) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index ed9fb8035..978f0826b 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -254,6 +254,9 @@ def _start_background_channel_refresh(self) -> None: 
and not self._emulator_host and not self._is_closed.is_set() ): + if CrossSync.is_async: + # raise error if not in an event loop + asyncio.get_running_loop() for channel_idx in range(self.transport.pool_size): refresh_task = CrossSync.create_task( self._manage_channel, diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index e671f9336..b256cdbcb 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -163,6 +163,7 @@ async def yield_to_event_loop(): """ await asyncio.sleep(0) + class _Sync_Impl(metaclass=_SyncGetAttr): is_async = False From c6810b69c4d57c35ccad70b8cb80829d7689cb5e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 15:17:02 -0700 Subject: [PATCH 086/360] got client unit tests passing --- google/cloud/bigtable/data/_async/client.py | 1 + tests/unit/data/_async/test_client.py | 32 ++++++++------------- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 978f0826b..5210143c0 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -280,6 +280,7 @@ async def close(self, timeout: float | None = None): if self._executor: self._executor.shutdown(wait=False) await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) + self._channel_refresh_tasks = [] async def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index d818266f7..f057ac853 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -332,23 +332,16 @@ async def test__ping_and_warm_instances(self): """ test ping and warm with mocked asyncio.gather """ + from google.cloud.bigtable.data._sync.cross_sync import CrossSync client_mock = mock.Mock() client_mock._execute_ping_and_warms = ( lambda *args: self._get_target_class()._execute_ping_and_warms( client_mock, *args ) ) - gather_tuple = ( - (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") - ) - with mock.patch.object(*gather_tuple, AsyncMock()) as gather: - if self.is_async: - # simulate gather by returning the same number of items as passed in - # gather is expected to return None for each coroutine passed - gather.side_effect = lambda *args, **kwargs: [None for _ in args] - else: - # submit is expected to call the function passed, and return the result - gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] + with mock.patch.object(CrossSync, "gather_partials", AsyncMock()) as gather: + # gather_partials is expected to call the function passed, and return the result + gather.side_effect = lambda partials, **kwargs: [None for _ in partials] channel = mock.Mock() # test with no instances client_mock._active_instances = [] @@ -356,8 +349,8 @@ async def test__ping_and_warm_instances(self): client_mock, channel ) assert len(result) == 0 - if self.is_async: - assert gather.call_args.kwargs == {"return_exceptions": True} + assert gather.call_args.kwargs["return_exceptions"] is True + assert gather.call_args.kwargs["sync_executor"] == client_mock._executor # test with instances client_mock._active_instances = [ (mock.Mock(), mock.Mock(), mock.Mock()) @@ -368,14 +361,12 @@ async def test__ping_and_warm_instances(self): client_mock, channel ) assert len(result) == 4 + 
gather.assert_called_once() + # expect one partial for each instance + partial_list = gather.call_args.args[0] + assert len(partial_list) == 4 if self.is_async: - gather.assert_called_once() gather.assert_awaited_once() - # expect one arg for each instance - assert len(gather.call_args.args) == 4 - else: - # expect one call for each instance - assert gather.call_count == 4 # check grpc call arguments grpc_call_args = channel.unary_unary().call_args_list for idx, (_, kwargs) in enumerate(grpc_call_args): @@ -1052,11 +1043,12 @@ async def test_close(self): @pytest.mark.asyncio async def test_close_with_timeout(self): + from google.cloud.bigtable.data._sync.cross_sync import CrossSync pool_size = 7 expected_timeout = 19 client = self._make_client(project="project-id", pool_size=pool_size) tasks = list(client._channel_refresh_tasks) - with mock.patch.object(asyncio, "wait_for", AsyncMock()) as wait_for_mock: + with mock.patch.object(CrossSync, "wait", AsyncMock()) as wait_for_mock: await client.close(timeout=expected_timeout) wait_for_mock.assert_called_once() wait_for_mock.assert_awaited() From f7b75230dfe21c2477346b82cc43a930dc6a3001 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 15:24:48 -0700 Subject: [PATCH 087/360] added missing await --- google/cloud/bigtable/data/_async/mutations_batcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index f4ce4856b..ecc5635c2 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -316,7 +316,7 @@ async def append(self, mutation_entry: RowMutationEntry): ): self._schedule_flush() # yield to the event loop to allow flush to run - CrossSync.yield_to_event_loop() + await CrossSync.yield_to_event_loop() def _schedule_flush(self) -> CrossSync.Future[None] | None: """Update the flush task to include the latest staged entries""" From d15c4a4be045b0aea074a56c53c75e664e1ab82c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 16:06:10 -0700 Subject: [PATCH 088/360] added typing to cross sync --- .../cloud/bigtable/data/_sync/cross_sync.py | 88 ++++++++++++++----- tests/unit/data/_async/test_client.py | 6 +- 2 files changed, 73 insertions(+), 21 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index b256cdbcb..852d7eb41 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -14,6 +14,7 @@ # from __future__ import annotations +from typing import TypeVar, Any, Awaitable, Callable, Coroutine, Sequence from typing_extensions import TypeAlias import asyncio @@ -24,6 +25,8 @@ import threading import queue +T = TypeVar("T") + class _AsyncGetAttr(type): def __getitem__(cls, item): @@ -59,7 +62,12 @@ def decorator(func): return decorator @classmethod - def sync_output(cls, sync_path, replace_symbols=None, mypy_ignore=None): + def sync_output( + cls, + sync_path: str, + replace_symbols: dict["str", "str"] | None = None, + mypy_ignore: list[str] | None = None, + ): replace_symbols = replace_symbols or {} mypy_ignore = mypy_ignore or [] @@ -78,8 +86,10 @@ def decorator(async_cls): @staticmethod async def gather_partials( - partial_list, return_exceptions=False, sync_executor=None - ): + partial_list: Sequence[Callable[[], Awaitable[T]]], + return_exceptions: bool = False, + sync_executor: 
concurrent.futures.ThreadPoolExecutor | None = None, + ) -> list[T | BaseException]: """ abstraction over asyncio.gather, but with a set of partial functions instead of coroutines, to work with sync functions. @@ -101,16 +111,23 @@ async def gather_partials( ) @staticmethod - async def wait(futures, timeout=None): + async def wait( + futures: Sequence[CrossSync.Future[T]], timeout: float | None = None + ) -> tuple[set[CrossSync.Future[T]], set[CrossSync.Future[T]]]: """ abstraction over asyncio.wait + + Return: + - a tuple of (done, pending) sets of futures """ if not futures: return set(), set() return await asyncio.wait(futures, timeout=timeout) @staticmethod - async def condition_wait(condition, timeout=None): + async def condition_wait( + condition: CrossSync.Condition, timeout: float | None = None + ) -> bool: """ abstraction over asyncio.Condition.wait @@ -123,7 +140,11 @@ async def condition_wait(condition, timeout=None): return False @staticmethod - async def event_wait(event, timeout=None, async_break_early=True): + async def event_wait( + event: CrossSync.Event, + timeout: float | None = None, + async_break_early: bool = True, + ) -> None: """ abstraction over asyncio.Event.wait @@ -145,25 +166,30 @@ async def event_wait(event, timeout=None, async_break_early=True): pass @staticmethod - def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): + def create_task( + fn: Callable[..., Coroutine[Any, Any, T]], + *fn_args, + sync_executor: concurrent.futures.ThreadPoolExecutor | None = None, + task_name: str | None = None, + **fn_kwargs, + ) -> CrossSync.Task[T]: """ abstraction over asyncio.create_task. Sync version implemented with threadpool executor sync_executor: ThreadPoolExecutor to use for sync operations. Ignored in async version """ - task = asyncio.create_task(fn(*fn_args, **fn_kwargs)) + task: CrossSync.Task[T] = asyncio.create_task(fn(*fn_args, **fn_kwargs)) if task_name and sys.version_info >= (3, 8): task.set_name(task_name) return task @staticmethod - async def yield_to_event_loop(): + async def yield_to_event_loop() -> None: """ Call asyncio.sleep(0) to yield to allow other tasks to run """ await asyncio.sleep(0) - class _Sync_Impl(metaclass=_SyncGetAttr): is_async = False @@ -179,7 +205,12 @@ class _Sync_Impl(metaclass=_SyncGetAttr): generated_replacements: dict[type, str] = {} @staticmethod - def wait(futures, timeout=None): + def wait( + futures: Sequence[CrossSync._Sync_Impl.Future[T]], + timeout: float | None = None, + ) -> tuple[ + set[CrossSync._Sync_Impl.Future[T]], set[CrossSync._Sync_Impl.Future[T]] + ]: """ abstraction over asyncio.wait """ @@ -188,36 +219,53 @@ def wait(futures, timeout=None): return concurrent.futures.wait(futures, timeout=timeout) @staticmethod - def condition_wait(condition, timeout=None): + def condition_wait( + condition: CrossSync._Sync_Impl.Condition, timeout: float | None = None + ) -> bool: """ returns False if the timeout is reached before the condition is set, otherwise True """ return condition.wait(timeout=timeout) @staticmethod - def event_wait(event, timeout=None, async_break_early=True): + def event_wait( + event: CrossSync._Sync_Impl.Event, + timeout: float | None = None, + async_break_early: bool = True, + ) -> None: event.wait(timeout=timeout) @staticmethod - def gather_partials(partial_list, return_exceptions=False, sync_executor=None): + def gather_partials( + partial_list: Sequence[Callable[[], T]], + return_exceptions: bool = False, + sync_executor: concurrent.futures.ThreadPoolExecutor | None = 
None, + ) -> list[T | BaseException]: if not partial_list: return [] if not sync_executor: raise ValueError("sync_executor is required for sync version") futures_list = [sync_executor.submit(partial) for partial in partial_list] - results_list = [] + results_list: list[T | BaseException] = [] for future in futures_list: - if future.exception(): + found_exc = future.exception() + if found_exc is not None: if return_exceptions: - results_list.append(future.exception()) + results_list.append(found_exc) else: - raise future.exception() + raise found_exc else: results_list.append(future.result()) return results_list @staticmethod - def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): + def create_task( + fn: Callable[..., T], + *fn_args, + sync_executor: concurrent.futures.ThreadPoolExecutor | None = None, + task_name: str | None = None, + **fn_kwargs, + ) -> CrossSync._Sync_Impl.Task[T]: """ abstraction over asyncio.create_task. Sync version implemented with threadpool executor @@ -228,5 +276,5 @@ def create_task(fn, *fn_args, sync_executor=None, task_name=None, **fn_kwargs): return sync_executor.submit(fn, *fn_args, **fn_kwargs) @staticmethod - def yield_to_event_loop(): + def yield_to_event_loop() -> None: pass diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index f057ac853..786d87c56 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -333,6 +333,7 @@ async def test__ping_and_warm_instances(self): test ping and warm with mocked asyncio.gather """ from google.cloud.bigtable.data._sync.cross_sync import CrossSync + client_mock = mock.Mock() client_mock._execute_ping_and_warms = ( lambda *args: self._get_target_class()._execute_ping_and_warms( @@ -1044,6 +1045,7 @@ async def test_close(self): @pytest.mark.asyncio async def test_close_with_timeout(self): from google.cloud.bigtable.data._sync.cross_sync import CrossSync + pool_size = 7 expected_timeout = 19 client = self._make_client(project="project-id", pool_size=pool_size) @@ -1328,7 +1330,9 @@ async def test_customizable_retryable_errors( retry_fn = f"CrossSync.{retry_fn}" else: retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" - with mock.patch(f"google.cloud.bigtable.data._sync.cross_sync.{retry_fn}") as retry_fn_mock: + with mock.patch( + f"google.cloud.bigtable.data._sync.cross_sync.{retry_fn}" + ) as retry_fn_mock: async with self._make_client() as client: table = client.get_table("instance-id", "table-id") expected_predicate = lambda a: a in expected_retryables # noqa From 360a20441d7a2c2e63ec099ee8da32b7228077ad Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 16:25:05 -0700 Subject: [PATCH 089/360] fixed type errors --- google/cloud/bigtable/data/_async/client.py | 4 ++-- google/cloud/bigtable/data/_async/mutations_batcher.py | 2 +- google/cloud/bigtable/data/_sync/client.py | 7 +++++-- google/cloud/bigtable/data/_sync/cross_sync.py | 2 +- google/cloud/bigtable/data/_sync/mutations_batcher.py | 5 ++--- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 5210143c0..f19efc0c5 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -192,7 +192,7 @@ def __init__( client_options=client_options, client_info=client_info, ) - self._is_closed = asyncio.Event() + self._is_closed = CrossSync.Event() self.transport = cast( PooledBigtableGrpcAsyncIOTransport, 
self._gapic_client.transport ) @@ -202,7 +202,7 @@ def __init__( # only remove instance from _active_instances when all associated tables remove it self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() - self._channel_refresh_tasks: list[asyncio.Task[None]] = [] + self._channel_refresh_tasks: list[CrossSync.Task[None]] = [] self._executor = ( concurrent.futures.ThreadPoolExecutor() if not CrossSync.is_async else None ) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index ecc5635c2..bd2bc2963 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -284,7 +284,7 @@ async def _timer_routine(self, interval: float | None) -> None: return None while not self._closed.is_set(): # wait until interval has passed, or until closed - await CrossSync.condition_wait(self._closed, timeout=interval) + await CrossSync.event_wait(self._closed, timeout=interval) if not self._closed.is_set() and self._staged_entries: self._schedule_flush() diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index bc8ec800f..2f7528cec 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -140,12 +140,12 @@ def __init__( client_options=client_options, client_info=client_info, ) - self._is_closed = asyncio.Event() + self._is_closed = CrossSync._Sync_Impl.Event() self.transport = cast(PooledBigtableGrpcTransport, self._gapic_client.transport) self._active_instances: Set[_helpers._WarmedInstanceKey] = set() self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() - self._channel_refresh_tasks: list[asyncio.Task[None]] = [] + self._channel_refresh_tasks: list[CrossSync._Sync_Impl.Task[None]] = [] self._executor = ( concurrent.futures.ThreadPoolExecutor() if not CrossSync._Sync_Impl.is_async @@ -191,6 +191,8 @@ def _start_background_channel_refresh(self) -> None: and (not self._emulator_host) and (not self._is_closed.is_set()) ): + if CrossSync._Sync_Impl.is_async: + asyncio.get_running_loop() for channel_idx in range(self.transport.pool_size): refresh_task = CrossSync._Sync_Impl.create_task( self._manage_channel, @@ -212,6 +214,7 @@ def close(self, timeout: float | None = None): if self._executor: self._executor.shutdown(wait=False) CrossSync._Sync_Impl.wait(self._channel_refresh_tasks, timeout=timeout) + self._channel_refresh_tasks = [] def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 852d7eb41..b6c4cb0f9 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -65,7 +65,7 @@ def decorator(func): def sync_output( cls, sync_path: str, - replace_symbols: dict["str", "str"] | None = None, + replace_symbols: dict["str", "str" | None ] | None = None, mypy_ignore: list[str] | None = None, ): replace_symbols = replace_symbols or {} diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index d4e6c7f0c..8c1dc4904 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -131,7 +131,7 @@ def 
_timer_routine(self, interval: float | None) -> None: if not interval or interval <= 0: return None while not self._closed.is_set(): - CrossSync._Sync_Impl.condition_wait(self._closed, timeout=interval) + CrossSync._Sync_Impl.event_wait(self._closed, timeout=interval) if not self._closed.is_set() and self._staged_entries: self._schedule_flush() @@ -289,8 +289,7 @@ def close(self): self._closed.set() self._flush_timer.cancel() self._schedule_flush() - CrossSync._Sync_Impl.wait(self._flush_jobs) - CrossSync._Sync_Impl.wait(self._flush_timer) + CrossSync._Sync_Impl.wait([*self._flush_jobs, self._flush_timer]) if self._sync_executor: with self._sync_executor: self._sync_executor.shutdown(wait=True) From 99b23e50bd7191e592ad4282baf2a4f7efc792d7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 16:25:38 -0700 Subject: [PATCH 090/360] remove unused metaclasses --- google/cloud/bigtable/data/_async/client.py | 6 ++++-- google/cloud/bigtable/data/_sync/client.py | 4 ++-- google/cloud/bigtable/data/_sync/cross_sync.py | 15 ++------------- 3 files changed, 8 insertions(+), 17 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index f19efc0c5..9d7616617 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -119,6 +119,8 @@ "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", "BigtableAsyncClient": "BigtableClient", "AsyncPooledChannel": "PooledChannel", + "_ReadRowsOperationAsync": "_ReadRowsOperation", + "_MutateRowsOperationAsync": "_MutateRowsOperation", }, ) class BigtableDataClientAsync(ClientWithProject): @@ -665,7 +667,7 @@ async def read_rows_stream( ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - row_merger = CrossSync[_ReadRowsOperationAsync]( + row_merger = _ReadRowsOperationAsync( query, self, operation_timeout=operation_timeout, @@ -1180,7 +1182,7 @@ async def bulk_mutate_rows( ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - operation = CrossSync[_MutateRowsOperationAsync]( + operation = _MutateRowsOperationAsync( self.client._gapic_client, self, mutation_entries, diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 2f7528cec..975b379a2 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -573,7 +573,7 @@ def read_rows_stream( operation_timeout, attempt_timeout, self ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - row_merger = CrossSync._Sync_Impl[_ReadRowsOperationAsync]( + row_merger = _ReadRowsOperationAsync( query, self, operation_timeout=operation_timeout, @@ -1070,7 +1070,7 @@ def bulk_mutate_rows( operation_timeout, attempt_timeout, self ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - operation = CrossSync._Sync_Impl[_MutateRowsOperationAsync]( + operation = _MutateRowsOperationAsync( self.client._gapic_client, self, mutation_entries, diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index b6c4cb0f9..9722cf4e8 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -28,18 +28,7 @@ T = TypeVar("T") -class _AsyncGetAttr(type): - def __getitem__(cls, item): - return item - - -class _SyncGetAttr(type): - def __getitem__(cls, item): - breakpoint() - return CrossSync.generated_replacements[item] - - 
-class CrossSync(metaclass=_AsyncGetAttr): +class CrossSync: SyncImports = False is_async = True @@ -190,7 +179,7 @@ async def yield_to_event_loop() -> None: """ await asyncio.sleep(0) - class _Sync_Impl(metaclass=_SyncGetAttr): + class _Sync_Impl: is_async = False sleep = time.sleep From 9246589b3912300fb3e7c5ec6d0d8e314783515a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 16:38:01 -0700 Subject: [PATCH 091/360] fixed warning in tests --- google/cloud/bigtable/data/_async/mutations_batcher.py | 2 +- tests/unit/data/_async/test_mutations_batcher.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index bd2bc2963..8aab2cb9c 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -284,7 +284,7 @@ async def _timer_routine(self, interval: float | None) -> None: return None while not self._closed.is_set(): # wait until interval has passed, or until closed - await CrossSync.event_wait(self._closed, timeout=interval) + await CrossSync.event_wait(self._closed, timeout=interval, async_break_early=False) if not self._closed.is_set() and self._staged_entries: self._schedule_flush() diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 170596f6e..b76eee300 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -552,6 +552,7 @@ async def test__start_flush_timer_call_when_closed( @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test__flush_timer(self, num_staged): """Timer should continue to call _schedule_flush in a loop""" + from google.cloud.bigtable.data._sync.cross_sync import CrossSync with mock.patch.object( self._get_target_class(), "_schedule_flush" ) as flush_mock: @@ -559,12 +560,7 @@ async def test__flush_timer(self, num_staged): async with self._make_one(flush_interval=expected_sleep) as instance: loop_num = 3 instance._staged_entries = [mock.Mock()] * num_staged - # mock different method depending on sync vs async - if self.is_async(): - sleep_obj, sleep_method = asyncio, "wait_for" - else: - sleep_obj, sleep_method = instance._closed, "wait" - with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + with mock.patch.object(CrossSync, "event_wait") as sleep_mock: sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] with pytest.raises(TabError): await self._get_target_class()._timer_routine( From 9cf3923235618558ca9dc220bd306ecbfe1c3b4e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 16:50:32 -0700 Subject: [PATCH 092/360] removed old sync tests --- tests/system/data/test_system.py | 782 ----- tests/unit/data/_sync/__init__.py | 0 tests/unit/data/_sync/test_autogen.py | 4678 ------------------------- 3 files changed, 5460 deletions(-) delete mode 100644 tests/system/data/test_system.py delete mode 100644 tests/unit/data/_sync/__init__.py delete mode 100644 tests/unit/data/_sync/test_autogen.py diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py deleted file mode 100644 index 33710d808..000000000 --- a/tests/system/data/test_system.py +++ /dev/null @@ -1,782 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is automatically generated by sync_surface_generator.py. Do not edit. - - -from .test_system_async import TEST_FAMILY, TEST_FAMILY_2 -from abc import ABC -import os -import pytest -import time -import uuid - -from google.api_core import retry -from google.api_core.exceptions import ClientError -from google.cloud.bigtable.data import BigtableDataClient -from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE -from google.cloud.environment_vars import BIGTABLE_EMULATOR - - -class TempRowBuilder(ABC): - """ - Used to add rows to a table for testing purposes. - """ - - def __init__(self, table): - self.rows = [] - self.table = table - - def add_row( - self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" - ): - if isinstance(value, str): - value = value.encode("utf-8") - elif isinstance(value, int): - value = value.to_bytes(8, byteorder="big", signed=True) - request = { - "table_name": self.table.table_name, - "row_key": row_key, - "mutations": [ - { - "set_cell": { - "family_name": family, - "column_qualifier": qualifier, - "value": value, - } - } - ], - } - self.table.client._gapic_client.mutate_row(request) - self.rows.append(row_key) - - def delete_rows(self): - if self.rows: - request = { - "table_name": self.table.table_name, - "entries": [ - {"row_key": row, "mutations": [{"delete_from_row": {}}]} - for row in self.rows - ], - } - self.table.client._gapic_client.mutate_rows(request) - - -class TestSystemSync(ABC): - @pytest.fixture(scope="session") - def client(self): - project = os.getenv("GOOGLE_CLOUD_PROJECT") or None - with BigtableDataClient(project=project, pool_size=4) as client: - yield client - - @pytest.fixture(scope="session") - def table(self, client, table_id, instance_id): - with client.get_table(instance_id, table_id) as table: - yield table - - @pytest.fixture(scope="session") - def column_family_config(self): - """specify column families to create when creating a new test table""" - from google.cloud.bigtable_admin_v2 import types - - return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} - - @pytest.fixture(scope="session") - def init_table_id(self): - """The table_id to use when creating a new test table""" - return f"test-table-{uuid.uuid4().hex}" - - @pytest.fixture(scope="session") - def cluster_config(self, project_id): - """Configuration for the clusters to use when creating a new instance""" - from google.cloud.bigtable_admin_v2 import types - - cluster = { - "test-cluster": types.Cluster( - location=f"projects/{project_id}/locations/us-central1-b", serve_nodes=1 - ) - } - return cluster - - @pytest.mark.usefixtures("table") - def _retrieve_cell_value(self, table, row_key): - """Helper to read an individual row""" - from google.cloud.bigtable.data import ReadRowsQuery - - row_list = table.read_rows(ReadRowsQuery(row_keys=row_key)) - assert len(row_list) == 1 - row = row_list[0] - cell = row.cells[0] - return cell.value - - def _create_row_and_mutation( - self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" - ): - """Helper to create a new row, and 
a sample set_cell mutation to change its value""" - from google.cloud.bigtable.data.mutations import SetCell - - row_key = uuid.uuid4().hex.encode() - family = TEST_FAMILY - qualifier = b"test-qualifier" - temp_rows.add_row( - row_key, family=family, qualifier=qualifier, value=start_value - ) - assert self._retrieve_cell_value(table, row_key) == start_value - mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) - return (row_key, mutation) - - @pytest.fixture(scope="function") - def temp_rows(self, table): - builder = TempRowBuilder(table) - yield builder - builder.delete_rows() - - @pytest.mark.usefixtures("table") - @pytest.mark.usefixtures("client") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=10) - def test_ping_and_warm_gapic(self, client, table): - """ - Simple ping rpc test - This test ensures channels are able to authenticate with backend - """ - request = {"name": table.instance_name} - client._gapic_client.ping_and_warm(request) - - @pytest.mark.usefixtures("table") - @pytest.mark.usefixtures("client") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_ping_and_warm(self, client, table): - """Test ping and warm from handwritten client""" - try: - channel = client.transport._grpc_channel.pool[0] - except Exception: - channel = client.transport._grpc_channel - results = client._ping_and_warm_instances(channel) - assert len(results) == 1 - assert results[0] is None - - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_mutation_set_cell(self, table, temp_rows): - """Ensure cells can be set properly""" - row_key = b"bulk_mutate" - new_value = uuid.uuid4().hex.encode() - (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - table.mutate_row(row_key, mutation) - assert self._retrieve_cell_value(table, row_key) == new_value - - @pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" - ) - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_sample_row_keys(self, client, table, temp_rows, column_split_config): - """Sample keys should return a single sample in small test tables""" - temp_rows.add_row(b"row_key_1") - temp_rows.add_row(b"row_key_2") - results = table.sample_row_keys() - assert len(results) == len(column_split_config) + 1 - for idx in range(len(column_split_config)): - assert results[idx][0] == column_split_config[idx] - assert isinstance(results[idx][1], int) - assert results[-1][0] == b"" - assert isinstance(results[-1][1], int) - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - def test_bulk_mutations_set_cell(self, client, table, temp_rows): - """Ensure cells can be set properly""" - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - table.bulk_mutate_rows([bulk_mutation]) - assert self._retrieve_cell_value(table, row_key) == new_value - - def test_bulk_mutations_raise_exception(self, client, table): - """If an invalid mutation is passed, an exception should be raised""" - from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell - 
from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup - from google.cloud.bigtable.data.exceptions import FailedMutationEntryError - - row_key = uuid.uuid4().hex.encode() - mutation = SetCell( - family="nonexistent", qualifier=b"test-qualifier", new_value=b"" - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - with pytest.raises(MutationsExceptionGroup) as exc: - table.bulk_mutate_rows([bulk_mutation]) - assert len(exc.value.exceptions) == 1 - entry_error = exc.value.exceptions[0] - assert isinstance(entry_error, FailedMutationEntryError) - assert entry_error.index == 0 - assert entry_error.entry == bulk_mutation - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_mutations_batcher_context_manager(self, client, table, temp_rows): - """test batcher with context manager. Should flush on exit""" - from google.cloud.bigtable.data.mutations import RowMutationEntry - - (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] - (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - (row_key2, mutation2) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - with table.mutations_batcher() as batcher: - batcher.append(bulk_mutation) - batcher.append(bulk_mutation2) - assert self._retrieve_cell_value(table, row_key) == new_value - assert len(batcher._staged_entries) == 0 - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_mutations_batcher_timer_flush(self, client, table, temp_rows): - """batch should occur after flush_interval seconds""" - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - flush_interval = 0.1 - with table.mutations_batcher(flush_interval=flush_interval) as batcher: - batcher.append(bulk_mutation) - time.sleep(0) - assert len(batcher._staged_entries) == 1 - time.sleep(flush_interval + 0.1) - assert len(batcher._staged_entries) == 0 - assert self._retrieve_cell_value(table, row_key) == new_value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_mutations_batcher_count_flush(self, client, table, temp_rows): - """batch should flush after flush_limit_mutation_count mutations""" - from google.cloud.bigtable.data.mutations import RowMutationEntry - - (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] - (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - (row_key2, mutation2) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: - batcher.append(bulk_mutation) - assert len(batcher._flush_jobs) == 0 - assert len(batcher._staged_entries) == 1 - batcher.append(bulk_mutation2) - assert len(batcher._flush_jobs) == 1 - for future in 
list(batcher._flush_jobs): - future - future.result() - assert len(batcher._staged_entries) == 0 - assert len(batcher._flush_jobs) == 0 - assert self._retrieve_cell_value(table, row_key) == new_value - assert self._retrieve_cell_value(table, row_key2) == new_value2 - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): - """batch should flush after flush_limit_bytes bytes""" - from google.cloud.bigtable.data.mutations import RowMutationEntry - - (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] - (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - (row_key2, mutation2) = self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 - with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: - batcher.append(bulk_mutation) - assert len(batcher._flush_jobs) == 0 - assert len(batcher._staged_entries) == 1 - batcher.append(bulk_mutation2) - assert len(batcher._flush_jobs) == 1 - assert len(batcher._staged_entries) == 0 - for future in list(batcher._flush_jobs): - future - future.result() - assert self._retrieve_cell_value(table, row_key) == new_value - assert self._retrieve_cell_value(table, row_key2) == new_value2 - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - def test_mutations_batcher_no_flush(self, client, table, temp_rows): - """test with no flush requirements met""" - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - start_value = b"unchanged" - (row_key, mutation) = self._create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - (row_key2, mutation2) = self._create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 - with table.mutations_batcher( - flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 - ) as batcher: - batcher.append(bulk_mutation) - assert len(batcher._staged_entries) == 1 - batcher.append(bulk_mutation2) - assert len(batcher._flush_jobs) == 0 - time.sleep(0.01) - assert len(batcher._staged_entries) == 2 - assert len(batcher._flush_jobs) == 0 - assert self._retrieve_cell_value(table, row_key) == start_value - assert self._retrieve_cell_value(table, row_key2) == start_value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @pytest.mark.parametrize( - "start,increment,expected", - [ - (0, 0, 0), - (0, 1, 1), - (0, -1, -1), - (1, 0, 1), - (0, -100, -100), - (0, 3000, 3000), - (10, 4, 14), - (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), - (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), - (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), - ], - ) - def test_read_modify_write_row_increment( - self, client, table, temp_rows, start, increment, expected - ): - """test read_modify_write_row""" - from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = 
b"test-qualifier" - temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) - rule = IncrementRule(family, qualifier, increment) - result = table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert len(result) == 1 - assert result[0].family == family - assert result[0].qualifier == qualifier - assert int(result[0]) == expected - assert self._retrieve_cell_value(table, row_key) == result[0].value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @pytest.mark.parametrize( - "start,append,expected", - [ - (b"", b"", b""), - ("", "", b""), - (b"abc", b"123", b"abc123"), - (b"abc", "123", b"abc123"), - ("", b"1", b"1"), - (b"abc", "", b"abc"), - (b"hello", b"world", b"helloworld"), - ], - ) - def test_read_modify_write_row_append( - self, client, table, temp_rows, start, append, expected - ): - """test read_modify_write_row""" - from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) - rule = AppendValueRule(family, qualifier, append) - result = table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert len(result) == 1 - assert result[0].family == family - assert result[0].qualifier == qualifier - assert result[0].value == expected - assert self._retrieve_cell_value(table, row_key) == result[0].value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - def test_read_modify_write_row_chained(self, client, table, temp_rows): - """test read_modify_write_row with multiple rules""" - from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule - from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - start_amount = 1 - increment_amount = 10 - temp_rows.add_row( - row_key, value=start_amount, family=family, qualifier=qualifier - ) - rule = [ - IncrementRule(family, qualifier, increment_amount), - AppendValueRule(family, qualifier, "hello"), - AppendValueRule(family, qualifier, "world"), - AppendValueRule(family, qualifier, "!"), - ] - result = table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert result[0].family == family - assert result[0].qualifier == qualifier - assert ( - result[0].value - == (start_amount + increment_amount).to_bytes(8, "big", signed=True) - + b"helloworld!" 
- ) - assert self._retrieve_cell_value(table, row_key) == result[0].value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @pytest.mark.parametrize( - "start_val,predicate_range,expected_result", - [(1, (0, 2), True), (-1, (0, 2), False)], - ) - def test_check_and_mutate( - self, client, table, temp_rows, start_val, predicate_range, expected_result - ): - """test that check_and_mutate_row applies the right mutations, and returns the right result""" - from google.cloud.bigtable.data.mutations import SetCell - from google.cloud.bigtable.data.row_filters import ValueRangeFilter - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - temp_rows.add_row(row_key, value=start_val, family=family, qualifier=qualifier) - false_mutation_value = b"false-mutation-value" - false_mutation = SetCell( - family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value - ) - true_mutation_value = b"true-mutation-value" - true_mutation = SetCell( - family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value - ) - predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) - result = table.check_and_mutate_row( - row_key, - predicate, - true_case_mutations=true_mutation, - false_case_mutations=false_mutation, - ) - assert result == expected_result - expected_value = ( - true_mutation_value if expected_result else false_mutation_value - ) - assert self._retrieve_cell_value(table, row_key) == expected_value - - @pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", - ) - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - def test_check_and_mutate_empty_request(self, client, table): - """check_and_mutate with no true or false mutations should raise an error""" - from google.api_core import exceptions - - with pytest.raises(exceptions.InvalidArgument) as e: - table.check_and_mutate_row( - b"row_key", None, true_case_mutations=None, false_case_mutations=None - ) - assert "No mutations provided" in str(e.value) - - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_read_rows_stream(self, table, temp_rows): - """Ensure that the read_rows_stream method works""" - temp_rows.add_row(b"row_key_1") - temp_rows.add_row(b"row_key_2") - generator = table.read_rows_stream({}) - first_row = generator.__next__() - second_row = generator.__next__() - assert first_row.row_key == b"row_key_1" - assert second_row.row_key == b"row_key_2" - with pytest.raises(StopIteration): - generator.__next__() - - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_read_rows(self, table, temp_rows): - """Ensure that the read_rows method works""" - temp_rows.add_row(b"row_key_1") - temp_rows.add_row(b"row_key_2") - row_list = table.read_rows({}) - assert len(row_list) == 2 - assert row_list[0].row_key == b"row_key_1" - assert row_list[1].row_key == b"row_key_2" - - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_read_rows_sharded_simple(self, table, temp_rows): - """Test read rows sharded with two queries""" - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - - temp_rows.add_row(b"a") - temp_rows.add_row(b"b") - temp_rows.add_row(b"c") - temp_rows.add_row(b"d") - query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) - query2 =
ReadRowsQuery(row_keys=[b"b", b"d"]) - row_list = table.read_rows_sharded([query1, query2]) - assert len(row_list) == 4 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"c" - assert row_list[2].row_key == b"b" - assert row_list[3].row_key == b"d" - - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_read_rows_sharded_from_sample(self, table, temp_rows): - """Test end-to-end sharding""" - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - from google.cloud.bigtable.data.read_rows_query import RowRange - - temp_rows.add_row(b"a") - temp_rows.add_row(b"b") - temp_rows.add_row(b"c") - temp_rows.add_row(b"d") - table_shard_keys = table.sample_row_keys() - query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) - shard_queries = query.shard(table_shard_keys) - row_list = table.read_rows_sharded(shard_queries) - assert len(row_list) == 3 - assert row_list[0].row_key == b"b" - assert row_list[1].row_key == b"c" - assert row_list[2].row_key == b"d" - - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_read_rows_sharded_filters_limits(self, table, temp_rows): - """Test read rows sharded with filters and limits""" - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - temp_rows.add_row(b"a") - temp_rows.add_row(b"b") - temp_rows.add_row(b"c") - temp_rows.add_row(b"d") - label_filter1 = ApplyLabelFilter("first") - label_filter2 = ApplyLabelFilter("second") - query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) - query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) - row_list = table.read_rows_sharded([query1, query2]) - assert len(row_list) == 3 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"b" - assert row_list[2].row_key == b"d" - assert row_list[0][0].labels == ["first"] - assert row_list[1][0].labels == ["second"] - assert row_list[2][0].labels == ["second"] - - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_read_rows_range_query(self, table, temp_rows): - """Ensure that the read_rows method works""" - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable.data import RowRange - - temp_rows.add_row(b"a") - temp_rows.add_row(b"b") - temp_rows.add_row(b"c") - temp_rows.add_row(b"d") - query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) - row_list = table.read_rows(query) - assert len(row_list) == 2 - assert row_list[0].row_key == b"b" - assert row_list[1].row_key == b"c" - - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def test_read_rows_single_key_query(self, table, temp_rows): - """Ensure that the read_rows method works with specified query""" - from google.cloud.bigtable.data import ReadRowsQuery - - temp_rows.add_row(b"a") - temp_rows.add_row(b"b") - temp_rows.add_row(b"c") - temp_rows.add_row(b"d") - query = ReadRowsQuery(row_keys=[b"a", b"c"]) - row_list = table.read_rows(query) - assert len(row_list) == 2 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"c" - - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - def 
test_read_rows_with_filter(self, table, temp_rows): - """ensure filters are applied""" - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - temp_rows.add_row(b"a") - temp_rows.add_row(b"b") - temp_rows.add_row(b"c") - temp_rows.add_row(b"d") - expected_label = "test-label" - row_filter = ApplyLabelFilter(expected_label) - query = ReadRowsQuery(row_filter=row_filter) - row_list = table.read_rows(query) - assert len(row_list) == 4 - for row in row_list: - assert row[0].labels == [expected_label] - - @pytest.mark.usefixtures("table") - def test_read_rows_stream_close(self, table, temp_rows): - """Ensure that the read_rows_stream can be closed""" - from google.cloud.bigtable.data import ReadRowsQuery - - temp_rows.add_row(b"row_key_1") - temp_rows.add_row(b"row_key_2") - query = ReadRowsQuery() - generator = table.read_rows_stream(query) - first_row = generator.__next__() - assert first_row.row_key == b"row_key_1" - generator.close() - with pytest.raises(StopIteration): - generator.__next__() - - @pytest.mark.usefixtures("table") - def test_read_row(self, table, temp_rows): - """Test read_row (single row helper)""" - from google.cloud.bigtable.data import Row - - temp_rows.add_row(b"row_key_1", value=b"value") - row = table.read_row(b"row_key_1") - assert isinstance(row, Row) - assert row.row_key == b"row_key_1" - assert row.cells[0].value == b"value" - - @pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", - ) - @pytest.mark.usefixtures("table") - def test_read_row_missing(self, table): - """Test read_row when row does not exist""" - from google.api_core import exceptions - - row_key = "row_key_not_exist" - result = table.read_row(row_key) - assert result is None - with pytest.raises(exceptions.InvalidArgument) as e: - table.read_row("") - assert "Row keys must be non-empty" in str(e) - - @pytest.mark.usefixtures("table") - def test_read_row_w_filter(self, table, temp_rows): - """Test read_row (single row helper)""" - from google.cloud.bigtable.data import Row - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - temp_rows.add_row(b"row_key_1", value=b"value") - expected_label = "test-label" - label_filter = ApplyLabelFilter(expected_label) - row = table.read_row(b"row_key_1", row_filter=label_filter) - assert isinstance(row, Row) - assert row.row_key == b"row_key_1" - assert row.cells[0].value == b"value" - assert row.cells[0].labels == [expected_label] - - @pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", - ) - @pytest.mark.usefixtures("table") - def test_row_exists(self, table, temp_rows): - from google.api_core import exceptions - - "Test row_exists with rows that exist and don't exist" - assert table.row_exists(b"row_key_1") is False - temp_rows.add_row(b"row_key_1") - assert table.row_exists(b"row_key_1") is True - assert table.row_exists("row_key_1") is True - assert table.row_exists(b"row_key_2") is False - assert table.row_exists("row_key_2") is False - assert table.row_exists("3") is False - temp_rows.add_row(b"3") - assert table.row_exists(b"3") is True - with pytest.raises(exceptions.InvalidArgument) as e: - table.row_exists("") - assert "Row keys must be non-empty" in str(e) - - @pytest.mark.usefixtures("table") - @retry.Retry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) - @pytest.mark.parametrize( - 
"cell_value,filter_input,expect_match", - [ - (b"abc", b"abc", True), - (b"abc", "abc", True), - (b".", ".", True), - (".*", ".*", True), - (".*", b".*", True), - ("a", ".*", False), - (b".*", b".*", True), - ("\\a", "\\a", True), - (b"\xe2\x98\x83", "☃", True), - ("☃", "☃", True), - ("\\C☃", "\\C☃", True), - (1, 1, True), - (2, 1, False), - (68, 68, True), - ("D", 68, False), - (68, "D", False), - (-1, -1, True), - (2852126720, 2852126720, True), - (-1431655766, -1431655766, True), - (-1431655766, -1, False), - ], - ) - def test_literal_value_filter( - self, table, temp_rows, cell_value, filter_input, expect_match - ): - """ - Literal value filter does complex escaping on re2 strings. - Make sure inputs are properly interpreted by the server - """ - from google.cloud.bigtable.data.row_filters import LiteralValueFilter - from google.cloud.bigtable.data import ReadRowsQuery - - f = LiteralValueFilter(filter_input) - temp_rows.add_row(b"row_key_1", value=cell_value) - query = ReadRowsQuery(row_filter=f) - row_list = table.read_rows(query) - assert len(row_list) == bool( - expect_match - ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" diff --git a/tests/unit/data/_sync/__init__.py b/tests/unit/data/_sync/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/data/_sync/test_autogen.py b/tests/unit/data/_sync/test_autogen.py deleted file mode 100644 index fda0fee15..000000000 --- a/tests/unit/data/_sync/test_autogen.py +++ /dev/null @@ -1,4678 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is automatically generated by sync_surface_generator.py. Do not edit. 
- - -from __future__ import annotations -from abc import ABC -from itertools import zip_longest -from tests.unit.data._async.test__mutate_rows import TestMutateRowsOperation -from tests.unit.data._async.test__read_rows import TestReadRowsOperation -from tests.unit.data._async.test_mutations_batcher import Test_FlowControl -from tests.unit.v2_client.test_row_merger import ReadRowsTest -from tests.unit.v2_client.test_row_merger import TestFile -from unittest import mock -import asyncio -import concurrent.futures -import grpc -import mock -import os -import pytest -import re -import threading -import time -import warnings - -from google.api_core import exceptions as core_exceptions -from google.api_core import grpc_helpers -from google.auth.credentials import AnonymousCredentials -from google.cloud.bigtable.data import ReadRowsQuery -from google.cloud.bigtable.data import TABLE_DEFAULT -from google.cloud.bigtable.data import Table -from google.cloud.bigtable.data import mutations -from google.cloud.bigtable.data.exceptions import InvalidChunk -from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete -from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule -from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule -from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery -from google.cloud.bigtable.data.row import Row -from google.cloud.bigtable_v2 import ReadRowsResponse -from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledBigtableGrpcTransport, -) -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledChannel, -) -from google.cloud.bigtable_v2.types import MutateRowsResponse -from google.cloud.bigtable_v2.types import ReadRowsResponse -from google.rpc import status_pb2 -import google.api_core.exceptions -import google.api_core.exceptions as core_exceptions -import google.api_core.retry - - -class TestMutateRowsOperation(ABC): - def _target_class(self): - from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation - - return _MutateRowsOperation - - def _make_one(self, *args, **kwargs): - if not args: - kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) - kwargs["table"] = kwargs.pop("table", mock.Mock()) - kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) - kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) - kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) - kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) - return self._target_class()(*args, **kwargs) - - def _make_mutation(self, count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation - - def _mock_stream(self, mutation_list, error_dict): - for idx, entry in enumerate(mutation_list): - code = error_dict.get(idx, 0) - yield MutateRowsResponse( - entries=[ - MutateRowsResponse.Entry( - index=idx, status=status_pb2.Status(code=code) - ) - ] - ) - - def _make_mock_gapic(self, mutation_list, error_dict=None): - mock_fn = mock.Mock() - if error_dict is None: - error_dict = {} - mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( - mutation_list, error_dict - ) - return mock_fn - - def test_ctor(self): - """test that constructor sets all the attributes correctly""" - from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto 
- from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - from google.api_core.exceptions import DeadlineExceeded - from google.api_core.exceptions import Aborted - - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] - operation_timeout = 0.05 - attempt_timeout = 0.01 - retryable_exceptions = () - instance = self._make_one( - client, - table, - entries, - operation_timeout, - attempt_timeout, - retryable_exceptions, - ) - assert client.mutate_rows.call_count == 0 - instance._gapic_fn() - assert client.mutate_rows.call_count == 1 - inner_kwargs = client.mutate_rows.call_args[1] - assert len(inner_kwargs) == 4 - assert inner_kwargs["table_name"] == table.table_name - assert inner_kwargs["app_profile_id"] == table.app_profile_id - assert inner_kwargs["retry"] is None - metadata = inner_kwargs["metadata"] - assert len(metadata) == 1 - assert metadata[0][0] == "x-goog-request-params" - assert str(table.table_name) in metadata[0][1] - assert str(table.app_profile_id) in metadata[0][1] - entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries] - assert instance.mutations == entries_w_pb - assert next(instance.timeout_generator) == attempt_timeout - assert instance.is_retryable is not None - assert instance.is_retryable(DeadlineExceeded("")) is False - assert instance.is_retryable(Aborted("")) is False - assert instance.is_retryable(_MutateRowsIncomplete("")) is True - assert instance.is_retryable(RuntimeError("")) is False - assert instance.remaining_indices == list(range(len(entries))) - assert instance.errors == {} - - def test_ctor_too_many_entries(self): - """should raise an error if an operation is created with more than 100,000 entries""" - from google.cloud.bigtable.data._async._mutate_rows import ( - _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, - ) - - assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100000 - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation()] * (_MUTATE_ROWS_REQUEST_MUTATION_LIMIT + 1) - operation_timeout = 0.05 - attempt_timeout = 0.01 - with pytest.raises(ValueError) as e: - self._make_one(client, table, entries, operation_timeout, attempt_timeout) - assert "mutate_rows requests can contain at most 100000 mutations" in str( - e.value - ) - assert "Found 100001" in str(e.value) - - def test_mutate_rows_operation(self): - """Test successful case of mutate_rows_operation""" - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] - operation_timeout = 0.05 - cls = self._target_class() - with mock.patch( - f"{cls.__module__}.{cls.__name__}._run_attempt", mock.Mock() - ) as attempt_mock: - instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout - ) - instance.start() - assert attempt_mock.call_count == 1 - - @pytest.mark.parametrize( - "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] - ) - def test_mutate_rows_attempt_exception(self, exc_type): - """exceptions raised from attempt should be raised in MutationsExceptionGroup""" - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] - operation_timeout = 0.05 - expected_exception = exc_type("test") - client.mutate_rows.side_effect = expected_exception - found_exc = None - try: - instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout - ) - instance._run_attempt() - except Exception as e: - found_exc = e - assert client.mutate_rows.call_count == 1 - 
assert type(found_exc) is exc_type - assert found_exc == expected_exception - assert len(instance.errors) == 2 - assert len(instance.remaining_indices) == 0 - - @pytest.mark.parametrize( - "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] - ) - def test_mutate_rows_exception(self, exc_type): - """exceptions raised from retryable should be raised in MutationsExceptionGroup""" - from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup - from google.cloud.bigtable.data.exceptions import FailedMutationEntryError - - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] - operation_timeout = 0.05 - expected_cause = exc_type("abort") - with mock.patch.object( - self._target_class(), "_run_attempt", mock.Mock() - ) as attempt_mock: - attempt_mock.side_effect = expected_cause - found_exc = None - try: - instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout - ) - instance.start() - except MutationsExceptionGroup as e: - found_exc = e - assert attempt_mock.call_count == 1 - assert len(found_exc.exceptions) == 2 - assert isinstance(found_exc.exceptions[0], FailedMutationEntryError) - assert isinstance(found_exc.exceptions[1], FailedMutationEntryError) - assert found_exc.exceptions[0].__cause__ == expected_cause - assert found_exc.exceptions[1].__cause__ == expected_cause - - @pytest.mark.parametrize( - "exc_type", [core_exceptions.DeadlineExceeded, RuntimeError] - ) - def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): - """If an exception fails but eventually passes, it should not raise an exception""" - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation()] - operation_timeout = 1 - expected_cause = exc_type("retry") - num_retries = 2 - with mock.patch.object( - self._target_class(), "_run_attempt", mock.Mock() - ) as attempt_mock: - attempt_mock.side_effect = [expected_cause] * num_retries + [None] - instance = self._make_one( - client, - table, - entries, - operation_timeout, - operation_timeout, - retryable_exceptions=(exc_type,), - ) - instance.start() - assert attempt_mock.call_count == num_retries + 1 - - def test_mutate_rows_incomplete_ignored(self): - """MutateRowsIncomplete exceptions should not be added to error list""" - from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup - from google.api_core.exceptions import DeadlineExceeded - - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation()] - operation_timeout = 0.05 - with mock.patch.object( - self._target_class(), "_run_attempt", mock.Mock() - ) as attempt_mock: - attempt_mock.side_effect = _MutateRowsIncomplete("ignored") - found_exc = None - try: - instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout - ) - instance.start() - except MutationsExceptionGroup as e: - found_exc = e - assert attempt_mock.call_count > 0 - assert len(found_exc.exceptions) == 1 - assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) - - def test_run_attempt_single_entry_success(self): - """Test mutating a single entry""" - mutation = self._make_mutation() - expected_timeout = 1.3 - mock_gapic_fn = self._make_mock_gapic({0: mutation}) - instance = self._make_one( - mutation_entries=[mutation], attempt_timeout=expected_timeout - ) - with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): - instance._run_attempt() - assert 
len(instance.remaining_indices) == 0 - assert mock_gapic_fn.call_count == 1 - (_, kwargs) = mock_gapic_fn.call_args - assert kwargs["timeout"] == expected_timeout - assert kwargs["entries"] == [mutation._to_pb()] - - def test_run_attempt_empty_request(self): - """Calling with no mutations should result in no API calls""" - mock_gapic_fn = self._make_mock_gapic([]) - instance = self._make_one(mutation_entries=[]) - instance._run_attempt() - assert mock_gapic_fn.call_count == 0 - - def test_run_attempt_partial_success_retryable(self): - """Some entries succeed, but one fails. Should report the proper index, and raise incomplete exception""" - from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - - success_mutation = self._make_mutation() - success_mutation_2 = self._make_mutation() - failure_mutation = self._make_mutation() - mutations = [success_mutation, failure_mutation, success_mutation_2] - mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) - instance = self._make_one(mutation_entries=mutations) - instance.is_retryable = lambda x: True - with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): - with pytest.raises(_MutateRowsIncomplete): - instance._run_attempt() - assert instance.remaining_indices == [1] - assert 0 not in instance.errors - assert len(instance.errors[1]) == 1 - assert instance.errors[1][0].grpc_status_code == 300 - assert 2 not in instance.errors - - def test_run_attempt_partial_success_non_retryable(self): - """Some entries succeed, but one fails. Exception marked as non-retryable. Do not raise incomplete error""" - success_mutation = self._make_mutation() - success_mutation_2 = self._make_mutation() - failure_mutation = self._make_mutation() - mutations = [success_mutation, failure_mutation, success_mutation_2] - mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) - instance = self._make_one(mutation_entries=mutations) - instance.is_retryable = lambda x: False - with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): - instance._run_attempt() - assert instance.remaining_indices == [] - assert 0 not in instance.errors - assert len(instance.errors[1]) == 1 - assert instance.errors[1][0].grpc_status_code == 300 - assert 2 not in instance.errors - - -class TestReadRowsOperation(ABC): - """ - Tests helper functions in the ReadRowsOperation class - in-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt - is tested in test_read_rows_acceptance test_client_read_rows, and conformance tests - """ - - @staticmethod - def _get_target_class(): - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - - return _ReadRowsOperation - - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) - - def test_ctor(self): - from google.cloud.bigtable.data import ReadRowsQuery - - row_limit = 91 - query = ReadRowsQuery(limit=row_limit) - client = mock.Mock() - client.read_rows = mock.Mock() - client.read_rows.return_value = None - table = mock.Mock() - table._client = client - table.table_name = "test_table" - table.app_profile_id = "test_profile" - expected_operation_timeout = 42 - expected_request_timeout = 44 - time_gen_mock = mock.Mock() - with mock.patch( - "google.cloud.bigtable.data._helpers._attempt_timeout_generator", - time_gen_mock, - ): - instance = self._make_one( - query, - table, - operation_timeout=expected_operation_timeout, - attempt_timeout=expected_request_timeout, - ) - assert time_gen_mock.call_count == 1 - 
time_gen_mock.assert_called_once_with( - expected_request_timeout, expected_operation_timeout - ) - assert instance._last_yielded_row_key is None - assert instance._remaining_count == row_limit - assert instance.operation_timeout == expected_operation_timeout - assert client.read_rows.call_count == 0 - assert instance._metadata == [ - ( - "x-goog-request-params", - "table_name=test_table&app_profile_id=test_profile", - ) - ] - assert instance.request.table_name == table.table_name - assert instance.request.app_profile_id == table.app_profile_id - assert instance.request.rows_limit == row_limit - - @pytest.mark.parametrize( - "in_keys,last_key,expected", - [ - (["b", "c", "d"], "a", ["b", "c", "d"]), - (["a", "b", "c"], "b", ["c"]), - (["a", "b", "c"], "c", []), - (["a", "b", "c"], "d", []), - (["d", "c", "b", "a"], "b", ["d", "c"]), - ], - ) - def test_revise_request_rowset_keys(self, in_keys, last_key, expected): - from google.cloud.bigtable_v2.types import RowSet as RowSetPB - from google.cloud.bigtable_v2.types import RowRange as RowRangePB - - in_keys = [key.encode("utf-8") for key in in_keys] - expected = [key.encode("utf-8") for key in expected] - last_key = last_key.encode("utf-8") - sample_range = RowRangePB(start_key_open=last_key) - row_set = RowSetPB(row_keys=in_keys, row_ranges=[sample_range]) - revised = self._get_target_class()._revise_request_rowset(row_set, last_key) - assert revised.row_keys == expected - assert revised.row_ranges == [sample_range] - - @pytest.mark.parametrize( - "in_ranges,last_key,expected", - [ - ( - [{"start_key_open": "b", "end_key_closed": "d"}], - "a", - [{"start_key_open": "b", "end_key_closed": "d"}], - ), - ( - [{"start_key_closed": "b", "end_key_closed": "d"}], - "a", - [{"start_key_closed": "b", "end_key_closed": "d"}], - ), - ( - [{"start_key_open": "a", "end_key_closed": "d"}], - "b", - [{"start_key_open": "b", "end_key_closed": "d"}], - ), - ( - [{"start_key_closed": "a", "end_key_open": "d"}], - "b", - [{"start_key_open": "b", "end_key_open": "d"}], - ), - ( - [{"start_key_closed": "b", "end_key_closed": "d"}], - "b", - [{"start_key_open": "b", "end_key_closed": "d"}], - ), - ([{"start_key_closed": "b", "end_key_closed": "d"}], "d", []), - ([{"start_key_closed": "b", "end_key_open": "d"}], "d", []), - ([{"start_key_closed": "b", "end_key_closed": "d"}], "e", []), - ([{"start_key_closed": "b"}], "z", [{"start_key_open": "z"}]), - ([{"start_key_closed": "b"}], "a", [{"start_key_closed": "b"}]), - ( - [{"end_key_closed": "z"}], - "a", - [{"start_key_open": "a", "end_key_closed": "z"}], - ), - ( - [{"end_key_open": "z"}], - "a", - [{"start_key_open": "a", "end_key_open": "z"}], - ), - ], - ) - def test_revise_request_rowset_ranges(self, in_ranges, last_key, expected): - from google.cloud.bigtable_v2.types import RowSet as RowSetPB - from google.cloud.bigtable_v2.types import RowRange as RowRangePB - - next_key = (last_key + "a").encode("utf-8") - last_key = last_key.encode("utf-8") - in_ranges = [ - RowRangePB(**{k: v.encode("utf-8") for (k, v) in r.items()}) - for r in in_ranges - ] - expected = [ - RowRangePB(**{k: v.encode("utf-8") for (k, v) in r.items()}) - for r in expected - ] - row_set = RowSetPB(row_ranges=in_ranges, row_keys=[next_key]) - revised = self._get_target_class()._revise_request_rowset(row_set, last_key) - assert revised.row_keys == [next_key] - assert revised.row_ranges == expected - - @pytest.mark.parametrize("last_key", ["a", "b", "c"]) - def test_revise_request_full_table(self, last_key): - from 
google.cloud.bigtable_v2.types import RowSet as RowSetPB - from google.cloud.bigtable_v2.types import RowRange as RowRangePB - - last_key = last_key.encode("utf-8") - row_set = RowSetPB() - for selected_set in [row_set, None]: - revised = self._get_target_class()._revise_request_rowset( - selected_set, last_key - ) - assert revised.row_keys == [] - assert len(revised.row_ranges) == 1 - assert revised.row_ranges[0] == RowRangePB(start_key_open=last_key) - - def test_revise_to_empty_rowset(self): - """revising to an empty rowset should raise error""" - from google.cloud.bigtable.data.exceptions import _RowSetComplete - from google.cloud.bigtable_v2.types import RowSet as RowSetPB - from google.cloud.bigtable_v2.types import RowRange as RowRangePB - - row_keys = [b"a", b"b", b"c"] - row_range = RowRangePB(end_key_open=b"c") - row_set = RowSetPB(row_keys=row_keys, row_ranges=[row_range]) - with pytest.raises(_RowSetComplete): - self._get_target_class()._revise_request_rowset(row_set, b"d") - - @pytest.mark.parametrize( - "start_limit,emit_num,expected_limit", - [ - (10, 0, 10), - (10, 1, 9), - (10, 10, 0), - (None, 10, None), - (None, 0, None), - (4, 2, 2), - ], - ) - def test_revise_limit(self, start_limit, emit_num, expected_limit): - """ - revise_limit should revise the request's limit field - - if limit is 0 (unlimited), it should never be revised - - if start_limit-emit_num == 0, the request should end early - - if the number emitted exceeds the new limit, an exception should - should be raised (tested in test_revise_limit_over_limit) - """ - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable_v2.types import ReadRowsResponse - - def awaitable_stream(): - def mock_stream(): - for i in range(emit_num): - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk( - row_key=str(i).encode(), - family_name="b", - qualifier=b"c", - value=b"d", - commit_row=True, - ) - ] - ) - - return mock_stream() - - query = ReadRowsQuery(limit=start_limit) - table = mock.Mock() - table.table_name = "table_name" - table.app_profile_id = "app_profile_id" - instance = self._make_one(query, table, 10, 10) - assert instance._remaining_count == start_limit - for val in instance.chunk_stream(awaitable_stream()): - pass - assert instance._remaining_count == expected_limit - - @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) - def test_revise_limit_over_limit(self, start_limit, emit_num): - """ - Should raise runtime error if we get in state where emit_num > start_num - (unless start_num == 0, which represents unlimited) - """ - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable_v2.types import ReadRowsResponse - from google.cloud.bigtable.data.exceptions import InvalidChunk - - def awaitable_stream(): - def mock_stream(): - for i in range(emit_num): - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk( - row_key=str(i).encode(), - family_name="b", - qualifier=b"c", - value=b"d", - commit_row=True, - ) - ] - ) - - return mock_stream() - - query = ReadRowsQuery(limit=start_limit) - table = mock.Mock() - table.table_name = "table_name" - table.app_profile_id = "app_profile_id" - instance = self._make_one(query, table, 10, 10) - assert instance._remaining_count == start_limit - with pytest.raises(InvalidChunk) as e: - for val in instance.chunk_stream(awaitable_stream()): - pass - assert "emit count exceeds row limit" in str(e.value) - - def test_close(self): - """ - should be able to close a stream safely with 
aclose. - Closed generators should raise StopIteration on next yield - """ - - def mock_stream(): - while True: - yield 1 - - with mock.patch.object( - self._get_target_class(), "_read_rows_attempt" - ) as mock_attempt: - instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) - wrapped_gen = mock_stream() - mock_attempt.return_value = wrapped_gen - gen = instance.start_operation() - gen.__next__() - gen.close() - with pytest.raises(StopIteration): - gen.__next__() - gen.close() - with pytest.raises(StopIteration): - wrapped_gen.__next__() - - def test_retryable_ignore_repeated_rows(self): - """Duplicate rows should cause an invalid chunk error""" - from google.cloud.bigtable.data.exceptions import InvalidChunk - from google.cloud.bigtable_v2.types import ReadRowsResponse - - row_key = b"duplicate" - - def mock_awaitable_stream(): - def mock_stream(): - while True: - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) - ] - ) - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) - ] - ) - - return mock_stream() - - instance = mock.Mock() - instance._last_yielded_row_key = None - instance._remaining_count = None - stream = self._get_target_class().chunk_stream( - instance, mock_awaitable_stream() - ) - stream.__next__() - with pytest.raises(InvalidChunk) as exc: - stream.__next__() - assert "row keys should be strictly increasing" in str(exc.value) - - -class Test_FlowControl(ABC): - @staticmethod - def _target_class(): - from google.cloud.bigtable.data._sync.mutations_batcher import _FlowControl - - return _FlowControl - - def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): - return self._target_class()(max_mutation_count, max_mutation_bytes) - - @staticmethod - def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation - - def test_ctor(self): - max_mutation_count = 9 - max_mutation_bytes = 19 - instance = self._make_one(max_mutation_count, max_mutation_bytes) - assert instance._max_mutation_count == max_mutation_count - assert instance._max_mutation_bytes == max_mutation_bytes - assert instance._in_flight_mutation_count == 0 - assert instance._in_flight_mutation_bytes == 0 - assert isinstance(instance._capacity_condition, threading.Condition) - - def test_ctor_invalid_values(self): - """Test that values are positive, and fit within expected limits""" - with pytest.raises(ValueError) as e: - self._make_one(0, 1) - assert "max_mutation_count must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - self._make_one(1, 0) - assert "max_mutation_bytes must be greater than 0" in str(e.value) - - @pytest.mark.parametrize( - "max_count,max_size,existing_count,existing_size,new_count,new_size,expected", - [ - (1, 1, 0, 0, 0, 0, True), - (1, 1, 1, 1, 1, 1, False), - (10, 10, 0, 0, 0, 0, True), - (10, 10, 0, 0, 9, 9, True), - (10, 10, 0, 0, 11, 9, True), - (10, 10, 0, 1, 11, 9, True), - (10, 10, 1, 0, 11, 9, False), - (10, 10, 0, 0, 9, 11, True), - (10, 10, 1, 0, 9, 11, True), - (10, 10, 0, 1, 9, 11, False), - (10, 1, 0, 0, 1, 0, True), - (1, 10, 0, 0, 0, 8, True), - (float("inf"), float("inf"), 0, 0, 10000000000.0, 10000000000.0, True), - (8, 8, 0, 0, 10000000000.0, 10000000000.0, True), - (12, 12, 6, 6, 5, 5, True), - (12, 12, 5, 5, 6, 6, True), - (12, 12, 6, 6, 6, 6, True), - (12, 12, 6, 6, 7, 7, False), - (12, 12, 0, 0, 13, 13, True), - (12, 12, 12, 0, 0, 13, 
True), - (12, 12, 0, 12, 13, 0, True), - (12, 12, 1, 1, 13, 13, False), - (12, 12, 1, 1, 0, 13, False), - (12, 12, 1, 1, 13, 0, False), - ], - ) - def test__has_capacity( - self, - max_count, - max_size, - existing_count, - existing_size, - new_count, - new_size, - expected, - ): - """_has_capacity should return True if the new mutation will not exceed the max count or size""" - instance = self._make_one(max_count, max_size) - instance._in_flight_mutation_count = existing_count - instance._in_flight_mutation_bytes = existing_size - assert instance._has_capacity(new_count, new_size) == expected - - @pytest.mark.parametrize( - "existing_count,existing_size,added_count,added_size,new_count,new_size", - [ - (0, 0, 0, 0, 0, 0), - (2, 2, 1, 1, 1, 1), - (2, 0, 1, 0, 1, 0), - (0, 2, 0, 1, 0, 1), - (10, 10, 0, 0, 10, 10), - (10, 10, 5, 5, 5, 5), - (0, 0, 1, 1, -1, -1), - ], - ) - def test_remove_from_flow_value_update( - self, - existing_count, - existing_size, - added_count, - added_size, - new_count, - new_size, - ): - """completed mutations should lower the inflight values""" - instance = self._make_one() - instance._in_flight_mutation_count = existing_count - instance._in_flight_mutation_bytes = existing_size - mutation = self._make_mutation(added_count, added_size) - instance.remove_from_flow(mutation) - assert instance._in_flight_mutation_count == new_count - assert instance._in_flight_mutation_bytes == new_size - - def test__remove_from_flow_unlock(self): - """capacity condition should notify after mutation is complete""" - import inspect - - instance = self._make_one(10, 10) - instance._in_flight_mutation_count = 10 - instance._in_flight_mutation_bytes = 10 - - def task_routine(): - with instance._capacity_condition: - instance._capacity_condition.wait_for( - lambda: instance._has_capacity(1, 1) - ) - - if inspect.iscoroutinefunction(task_routine): - task = threading.Thread(task_routine()) - task_alive = lambda: not task.done() - else: - import threading - - thread = threading.Thread(target=task_routine) - thread.start() - task_alive = thread.is_alive - time.sleep(0.05) - assert task_alive() is True - mutation = self._make_mutation(count=0, size=5) - instance.remove_from_flow([mutation]) - time.sleep(0.05) - assert instance._in_flight_mutation_count == 10 - assert instance._in_flight_mutation_bytes == 5 - assert task_alive() is True - instance._in_flight_mutation_bytes = 10 - mutation = self._make_mutation(count=5, size=0) - instance.remove_from_flow([mutation]) - time.sleep(0.05) - assert instance._in_flight_mutation_count == 5 - assert instance._in_flight_mutation_bytes == 10 - assert task_alive() is True - instance._in_flight_mutation_count = 10 - mutation = self._make_mutation(count=5, size=5) - instance.remove_from_flow([mutation]) - time.sleep(0.05) - assert instance._in_flight_mutation_count == 5 - assert instance._in_flight_mutation_bytes == 5 - assert task_alive() is False - - @pytest.mark.parametrize( - "mutations,count_cap,size_cap,expected_results", - [ - ([(5, 5), (1, 1), (1, 1)], 10, 10, [[(5, 5), (1, 1), (1, 1)]]), - ([(1, 1), (1, 1), (1, 1)], 1, 1, [[(1, 1)], [(1, 1)], [(1, 1)]]), - ([(1, 1), (1, 1), (1, 1)], 2, 10, [[(1, 1), (1, 1)], [(1, 1)]]), - ([(1, 1), (1, 1), (1, 1)], 10, 2, [[(1, 1), (1, 1)], [(1, 1)]]), - ( - [(1, 1), (5, 5), (4, 1), (1, 4), (1, 1)], - 5, - 5, - [[(1, 1)], [(5, 5)], [(4, 1), (1, 4)], [(1, 1)]], - ), - ], - ) - def test_add_to_flow(self, mutations, count_cap, size_cap, expected_results): - """Test batching with various flow control settings""" -
mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] - instance = self._make_one(count_cap, size_cap) - i = 0 - for batch in instance.add_to_flow(mutation_objs): - expected_batch = expected_results[i] - assert len(batch) == len(expected_batch) - for j in range(len(expected_batch)): - assert len(batch[j].mutations) == expected_batch[j][0] - assert batch[j].size() == expected_batch[j][1] - instance.remove_from_flow(batch) - i += 1 - assert i == len(expected_results) - - @pytest.mark.parametrize( - "mutations,max_limit,expected_results", - [ - ([(1, 1)] * 11, 10, [[(1, 1)] * 10, [(1, 1)]]), - ([(1, 1)] * 10, 1, [[(1, 1)] for _ in range(10)]), - ([(1, 1)] * 10, 2, [[(1, 1), (1, 1)] for _ in range(5)]), - ], - ) - def test_add_to_flow_max_mutation_limits( - self, mutations, max_limit, expected_results - ): - """ - Test flow control running up against the max API limit - Should submit request early, even if the flow control has room for more - """ - async_patch = mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", - max_limit, - ) - sync_patch = mock.patch( - "google.cloud.bigtable.data._sync._autogen._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", - max_limit, - ) - with async_patch, sync_patch: - mutation_objs = [ - self._make_mutation(count=m[0], size=m[1]) for m in mutations - ] - instance = self._make_one(float("inf"), float("inf")) - i = 0 - for batch in instance.add_to_flow(mutation_objs): - expected_batch = expected_results[i] - assert len(batch) == len(expected_batch) - for j in range(len(expected_batch)): - assert len(batch[j].mutations) == expected_batch[j][0] - assert batch[j].size() == expected_batch[j][1] - instance.remove_from_flow(batch) - i += 1 - assert i == len(expected_results) - - def test_add_to_flow_oversize(self): - """mutations over the flow control limits should still be accepted""" - instance = self._make_one(2, 3) - large_size_mutation = self._make_mutation(count=1, size=10) - large_count_mutation = self._make_mutation(count=10, size=1) - results = [out for out in instance.add_to_flow([large_size_mutation])] - assert len(results) == 1 - instance.remove_from_flow(results[0]) - count_results = [out for out in instance.add_to_flow(large_count_mutation)] - assert len(count_results) == 1 - - -class TestMutationsBatcher(ABC): - def _get_target_class(self): - from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher - - return MutationsBatcher - - @staticmethod - def is_async(): - return False - - def _make_one(self, table=None, **kwargs): - from google.api_core.exceptions import DeadlineExceeded - from google.api_core.exceptions import ServiceUnavailable - - if table is None: - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 10 - table.default_mutate_rows_retryable_errors = ( - DeadlineExceeded, - ServiceUnavailable, - ) - return self._get_target_class()(table, **kwargs) - - @staticmethod - def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation - - def test_ctor_defaults(self): - with mock.patch.object( - self._get_target_class(), - "_timer_routine", - return_value=concurrent.futures.Future(), - ) as flush_timer_mock: - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = [Exception] - with 
self._make_one(table) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._max_mutation_count == 100000 - assert instance._flow_control._max_mutation_bytes == 104857600 - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - assert ( - instance._operation_timeout - == table.default_mutate_rows_operation_timeout - ) - assert ( - instance._attempt_timeout - == table.default_mutate_rows_attempt_timeout - ) - assert ( - instance._retryable_errors - == table.default_mutate_rows_retryable_errors - ) - time.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] == 5 - assert isinstance(instance._flush_timer, concurrent.futures.Future) - - def test_ctor_explicit(self): - """Test with explicit parameters""" - with mock.patch.object( - self._get_target_class(), - "_timer_routine", - return_value=concurrent.futures.Future(), - ) as flush_timer_mock: - table = mock.Mock() - flush_interval = 20 - flush_limit_count = 17 - flush_limit_bytes = 19 - flow_control_max_mutation_count = 1001 - flow_control_max_bytes = 12 - operation_timeout = 11 - attempt_timeout = 2 - retryable_errors = [Exception] - with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - flow_control_max_mutation_count=flow_control_max_mutation_count, - flow_control_max_bytes=flow_control_max_bytes, - batch_operation_timeout=operation_timeout, - batch_attempt_timeout=attempt_timeout, - batch_retryable_errors=retryable_errors, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert ( - instance._flow_control._max_mutation_count - == flow_control_max_mutation_count - ) - assert ( - instance._flow_control._max_mutation_bytes == flow_control_max_bytes - ) - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - assert instance._operation_timeout == operation_timeout - assert instance._attempt_timeout == attempt_timeout - assert instance._retryable_errors == retryable_errors - time.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] == flush_interval - assert isinstance(instance._flush_timer, concurrent.futures.Future) - - def test_ctor_no_flush_limits(self): - """Test with None for flush limits""" - with mock.patch.object( - self._get_target_class(), - "_timer_routine", - return_value=concurrent.futures.Future(), - ) as flush_timer_mock: - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = () - flush_interval = None - flush_limit_count = None - flush_limit_bytes = None 
- with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._staged_entries == [] - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - time.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] is None - assert isinstance(instance._flush_timer, concurrent.futures.Future) - - def test_ctor_invalid_values(self): - """Test that timeout values are positive, and fit within expected limits""" - with pytest.raises(ValueError) as e: - self._make_one(batch_operation_timeout=-1) - assert "operation_timeout must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - self._make_one(batch_attempt_timeout=-1) - assert "attempt_timeout must be greater than 0" in str(e.value) - - def test_default_argument_consistency(self): - """ - We supply default arguments in MutationsBatcherAsync.__init__, and in - table.mutations_batcher. Make sure any changes to defaults are applied to - both places - """ - import inspect - - get_batcher_signature = dict( - inspect.signature(Table.mutations_batcher).parameters - ) - get_batcher_signature.pop("self") - batcher_init_signature = dict( - inspect.signature(self._get_target_class()).parameters - ) - batcher_init_signature.pop("table") - assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) - assert len(get_batcher_signature) == 8 - assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) - for arg_name in get_batcher_signature.keys(): - assert ( - get_batcher_signature[arg_name].default - == batcher_init_signature[arg_name].default - ) - - @pytest.mark.parametrize("input_val", [None, 0, -1]) - def test__start_flush_timer_w_empty_input(self, input_val): - """Empty/invalid timer should return immediately""" - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: - with self._make_one() as instance: - if self.is_async(): - (sleep_obj, sleep_method) = (asyncio, "wait_for") - else: - (sleep_obj, sleep_method) = (instance._closed, "wait") - with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: - result = instance._timer_routine(input_val) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 - assert result is None - - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test__start_flush_timer_call_when_closed(self): - """closed batcher's timer should return immediately""" - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: - with self._make_one() as instance: - instance.close() - flush_mock.reset_mock() - if self.is_async(): - (sleep_obj, sleep_method) = (asyncio, "wait_for") - else: - (sleep_obj, sleep_method) = (instance._closed, "wait") - with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: - instance._timer_routine(10) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 - - @pytest.mark.parametrize("num_staged", [0, 1, 10]) - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test__flush_timer(self, num_staged): 
- """Timer should continue to call _schedule_flush in a loop""" - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: - expected_sleep = 12 - with self._make_one(flush_interval=expected_sleep) as instance: - loop_num = 3 - instance._staged_entries = [mock.Mock()] * num_staged - if self.is_async(): - (sleep_obj, sleep_method) = (asyncio, "wait_for") - else: - (sleep_obj, sleep_method) = (instance._closed, "wait") - with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: - sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] - with pytest.raises(TabError): - self._get_target_class()._timer_routine( - instance, expected_sleep - ) - instance._flush_timer = concurrent.futures.Future() - assert sleep_mock.call_count == loop_num + 1 - sleep_kwargs = sleep_mock.call_args[1] - assert sleep_kwargs["timeout"] == expected_sleep - assert flush_mock.call_count == (0 if num_staged == 0 else loop_num) - - def test__flush_timer_close(self): - """Timer should continue terminate after close""" - with mock.patch.object(self._get_target_class(), "_schedule_flush"): - with self._make_one() as instance: - with mock.patch("asyncio.sleep"): - time.sleep(0.5) - assert instance._flush_timer.done() is False - instance.close() - time.sleep(0.1) - assert instance._flush_timer.done() is True - - def test_append_closed(self): - """Should raise exception""" - instance = self._make_one() - instance.close() - with pytest.raises(RuntimeError): - instance.append(mock.Mock()) - - def test_append_wrong_mutation(self): - """ - Mutation objects should raise an exception. - Only support RowMutationEntry - """ - from google.cloud.bigtable.data.mutations import DeleteAllFromRow - - with self._make_one() as instance: - expected_error = "invalid mutation type: DeleteAllFromRow. 
Only RowMutationEntry objects are supported by batcher" - with pytest.raises(ValueError) as e: - instance.append(DeleteAllFromRow()) - assert str(e.value) == expected_error - - def test_append_outside_flow_limits(self): - """entries larger than mutation limits are still processed""" - with self._make_one( - flow_control_max_mutation_count=1, flow_control_max_bytes=1 - ) as instance: - oversized_entry = self._make_mutation(count=0, size=2) - instance.append(oversized_entry) - assert instance._staged_entries == [oversized_entry] - assert instance._staged_count == 0 - assert instance._staged_bytes == 2 - instance._staged_entries = [] - with self._make_one( - flow_control_max_mutation_count=1, flow_control_max_bytes=1 - ) as instance: - overcount_entry = self._make_mutation(count=2, size=0) - instance.append(overcount_entry) - assert instance._staged_entries == [overcount_entry] - assert instance._staged_count == 2 - assert instance._staged_bytes == 0 - instance._staged_entries = [] - - def test_append_flush_runs_after_limit_hit(self): - """ - If the user appends a bunch of entries above the flush limits back-to-back, - it should still flush in a single task - """ - with mock.patch.object( - self._get_target_class(), "_execute_mutate_rows" - ) as op_mock: - with self._make_one(flush_limit_bytes=100) as instance: - - def mock_call(*args, **kwargs): - return [] - - op_mock.side_effect = mock_call - instance.append(self._make_mutation(size=99)) - num_entries = 10 - for _ in range(num_entries): - instance.append(self._make_mutation(size=1)) - instance._wait_for_batch_results(*instance._flush_jobs) - assert op_mock.call_count == 1 - sent_batch = op_mock.call_args[0][0] - assert len(sent_batch) == 2 - assert len(instance._staged_entries) == num_entries - 1 - - @pytest.mark.parametrize( - "flush_count,flush_bytes,mutation_count,mutation_bytes,expect_flush", - [ - (10, 10, 1, 1, False), - (10, 10, 9, 9, False), - (10, 10, 10, 1, True), - (10, 10, 1, 10, True), - (10, 10, 10, 10, True), - (1, 1, 10, 10, True), - (1, 1, 0, 0, False), - ], - ) - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test_append( - self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush - ): - """test appending different mutations, and checking if it causes a flush""" - with self._make_one( - flush_limit_mutation_count=flush_count, flush_limit_bytes=flush_bytes - ) as instance: - assert instance._staged_count == 0 - assert instance._staged_bytes == 0 - assert instance._staged_entries == [] - mutation = self._make_mutation(count=mutation_count, size=mutation_bytes) - with mock.patch.object(instance, "_schedule_flush") as flush_mock: - instance.append(mutation) - assert flush_mock.call_count == bool(expect_flush) - assert instance._staged_count == mutation_count - assert instance._staged_bytes == mutation_bytes - assert instance._staged_entries == [mutation] - instance._staged_entries = [] - - def test_append_multiple_sequentially(self): - """Append multiple mutations""" - with self._make_one( - flush_limit_mutation_count=8, flush_limit_bytes=8 - ) as instance: - assert instance._staged_count == 0 - assert instance._staged_bytes == 0 - assert instance._staged_entries == [] - mutation = self._make_mutation(count=2, size=3) - with mock.patch.object(instance, "_schedule_flush") as flush_mock: - instance.append(mutation) - assert flush_mock.call_count == 0 - assert instance._staged_count == 2 - assert instance._staged_bytes == 3 - assert len(instance._staged_entries) == 1 - instance.append(mutation) - 
assert flush_mock.call_count == 0 - assert instance._staged_count == 4 - assert instance._staged_bytes == 6 - assert len(instance._staged_entries) == 2 - instance.append(mutation) - assert flush_mock.call_count == 1 - assert instance._staged_count == 6 - assert instance._staged_bytes == 9 - assert len(instance._staged_entries) == 3 - instance._staged_entries = [] - - def test_flush_flow_control_concurrent_requests(self): - """requests should happen in parallel if flow control breaks up single flush into batches""" - import time - - num_calls = 10 - fake_mutations = [self._make_mutation(count=1) for _ in range(num_calls)] - with self._make_one(flow_control_max_mutation_count=1) as instance: - with mock.patch.object( - instance, "_execute_mutate_rows", mock.Mock() - ) as op_mock: - - def mock_call(*args, **kwargs): - time.sleep(0.1) - return [] - - op_mock.side_effect = mock_call - start_time = time.monotonic() - instance._staged_entries = fake_mutations - instance._schedule_flush() - time.sleep(0.01) - for i in range(num_calls): - instance._flow_control.remove_from_flow( - [self._make_mutation(count=1)] - ) - time.sleep(0.01) - instance._wait_for_batch_results(*instance._flush_jobs) - duration = time.monotonic() - start_time - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert duration < 0.5 - assert op_mock.call_count == num_calls - - def test_schedule_flush_no_mutations(self): - """schedule flush should return None if no staged mutations""" - with self._make_one() as instance: - with mock.patch.object(instance, "_flush_internal") as flush_mock: - for i in range(3): - assert instance._schedule_flush() is None - assert flush_mock.call_count == 0 - - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test_schedule_flush_with_mutations(self): - """if new mutations exist, should add a new flush task to _flush_jobs""" - with self._make_one() as instance: - with mock.patch.object(instance, "_flush_internal") as flush_mock: - if not self.is_async(): - flush_mock.side_effect = lambda x: time.sleep(0.1) - for i in range(1, 4): - mutation = mock.Mock() - instance._staged_entries = [mutation] - instance._schedule_flush() - assert instance._staged_entries == [] - time.sleep(0) - assert instance._staged_entries == [] - assert instance._staged_count == 0 - assert instance._staged_bytes == 0 - assert flush_mock.call_count == 1 - flush_mock.reset_mock() - - def test__flush_internal(self): - """ - _flush_internal should: - - await previous flush call - - delegate batching to _flow_control - - call _execute_mutate_rows on each batch - - update self.exceptions and self._entries_processed_since_last_raise - """ - num_entries = 10 - with self._make_one() as instance: - with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: - with mock.patch.object( - instance._flow_control, "add_to_flow" - ) as flow_mock: - - def gen(x): - yield x - - flow_mock.side_effect = lambda x: gen(x) - mutations = [self._make_mutation(count=1, size=1)] * num_entries - instance._flush_internal(mutations) - assert instance._entries_processed_since_last_raise == num_entries - assert execute_mock.call_count == 1 - assert flow_mock.call_count == 1 - instance._oldest_exceptions.clear() - instance._newest_exceptions.clear() - - def test_flush_clears_job_list(self): - """ - a job should be added to _flush_jobs when _schedule_flush is called, - and removed when it completes - """ - with self._make_one() as instance: - with mock.patch.object( - instance, 
"_flush_internal", mock.Mock() - ) as flush_mock: - if not self.is_async(): - flush_mock.side_effect = lambda x: time.sleep(0.1) - mutations = [self._make_mutation(count=1, size=1)] - instance._staged_entries = mutations - assert instance._flush_jobs == set() - new_job = instance._schedule_flush() - assert instance._flush_jobs == {new_job} - if self.is_async(): - new_job - else: - new_job.result() - assert instance._flush_jobs == set() - - @pytest.mark.parametrize( - "num_starting,num_new_errors,expected_total_errors", - [ - (0, 0, 0), - (0, 1, 1), - (0, 2, 2), - (1, 0, 1), - (1, 1, 2), - (10, 2, 12), - (10, 20, 20), - ], - ) - def test__flush_internal_with_errors( - self, num_starting, num_new_errors, expected_total_errors - ): - """errors returned from _execute_mutate_rows should be added to internal exceptions""" - from google.cloud.bigtable.data import exceptions - - num_entries = 10 - expected_errors = [ - exceptions.FailedMutationEntryError(mock.Mock(), mock.Mock(), ValueError()) - ] * num_new_errors - with self._make_one() as instance: - instance._oldest_exceptions = [mock.Mock()] * num_starting - with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: - execute_mock.return_value = expected_errors - with mock.patch.object( - instance._flow_control, "add_to_flow" - ) as flow_mock: - - def gen(x): - yield x - - flow_mock.side_effect = lambda x: gen(x) - mutations = [self._make_mutation(count=1, size=1)] * num_entries - instance._flush_internal(mutations) - assert instance._entries_processed_since_last_raise == num_entries - assert execute_mock.call_count == 1 - assert flow_mock.call_count == 1 - found_exceptions = instance._oldest_exceptions + list( - instance._newest_exceptions - ) - assert len(found_exceptions) == expected_total_errors - for i in range(num_starting, expected_total_errors): - assert found_exceptions[i] == expected_errors[i - num_starting] - assert found_exceptions[i].index is None - instance._oldest_exceptions.clear() - instance._newest_exceptions.clear() - - def _mock_gapic_return(self, num=5): - from google.cloud.bigtable_v2.types import MutateRowsResponse - from google.rpc import status_pb2 - - def gen(num): - for i in range(num): - entry = MutateRowsResponse.Entry( - index=i, status=status_pb2.Status(code=0) - ) - yield MutateRowsResponse(entries=[entry]) - - return gen(num) - - def test_timer_flush_end_to_end(self): - """Flush should automatically trigger after flush_interval""" - num_nutations = 10 - mutations = [self._make_mutation(count=2, size=2)] * num_nutations - with self._make_one(flush_interval=0.05) as instance: - instance._table.default_operation_timeout = 10 - instance._table.default_attempt_timeout = 9 - with mock.patch.object( - instance._table.client._gapic_client, "mutate_rows" - ) as gapic_mock: - gapic_mock.side_effect = ( - lambda *args, **kwargs: self._mock_gapic_return(num_nutations) - ) - for m in mutations: - instance.append(m) - assert instance._entries_processed_since_last_raise == 0 - time.sleep(0.1) - assert instance._entries_processed_since_last_raise == num_nutations - - def test__execute_mutate_rows(self): - if self.is_async(): - mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" - else: - mutate_path = "_sync._mutate_rows._MutateRowsOperation" - with mock.patch(f"google.cloud.bigtable.data.{mutate_path}") as mutate_rows: - mutate_rows.return_value = mock.Mock() - start_operation = mutate_rows().start - table = mock.Mock() - table.table_name = "test-table" - table.app_profile_id = "test-app-profile" 
- table.default_mutate_rows_operation_timeout = 17 - table.default_mutate_rows_attempt_timeout = 13 - table.default_mutate_rows_retryable_errors = () - with self._make_one(table) as instance: - batch = [self._make_mutation()] - result = instance._execute_mutate_rows(batch) - assert start_operation.call_count == 1 - (args, kwargs) = mutate_rows.call_args - assert args[0] == table.client._gapic_client - assert args[1] == table - assert args[2] == batch - kwargs["operation_timeout"] == 17 - kwargs["attempt_timeout"] == 13 - assert result == [] - - def test__execute_mutate_rows_returns_errors(self): - """Errors from operation should be retruned as list""" - from google.cloud.bigtable.data.exceptions import ( - MutationsExceptionGroup, - FailedMutationEntryError, - ) - - if self.is_async(): - mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" - else: - mutate_path = "_sync._mutate_rows._MutateRowsOperation" - with mock.patch( - f"google.cloud.bigtable.data.{mutate_path}.start" - ) as mutate_rows: - err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) - err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) - mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 17 - table.default_mutate_rows_attempt_timeout = 13 - table.default_mutate_rows_retryable_errors = () - with self._make_one(table) as instance: - batch = [self._make_mutation()] - result = instance._execute_mutate_rows(batch) - assert len(result) == 2 - assert result[0] == err1 - assert result[1] == err2 - assert result[0].index is None - assert result[1].index is None - - def test__raise_exceptions(self): - """Raise exceptions and reset error state""" - from google.cloud.bigtable.data import exceptions - - expected_total = 1201 - expected_exceptions = [RuntimeError("mock")] * 3 - with self._make_one() as instance: - instance._oldest_exceptions = expected_exceptions - instance._entries_processed_since_last_raise = expected_total - try: - instance._raise_exceptions() - except exceptions.MutationsExceptionGroup as exc: - assert list(exc.exceptions) == expected_exceptions - assert str(expected_total) in str(exc) - assert instance._entries_processed_since_last_raise == 0 - (instance._oldest_exceptions, instance._newest_exceptions) = ([], []) - instance._raise_exceptions() - - def test___aenter__(self): - """Should return self""" - with self._make_one() as instance: - assert instance.__enter__() == instance - - def test___aexit__(self): - """aexit should call close""" - with self._make_one() as instance: - with mock.patch.object(instance, "close") as close_mock: - instance.__exit__(None, None, None) - assert close_mock.call_count == 1 - - def test_close(self): - """Should clean up all resources""" - with self._make_one() as instance: - with mock.patch.object(instance, "_schedule_flush") as flush_mock: - with mock.patch.object(instance, "_raise_exceptions") as raise_mock: - instance.close() - assert instance.closed is True - assert instance._flush_timer.done() is True - assert instance._flush_jobs == set() - assert flush_mock.call_count == 1 - assert raise_mock.call_count == 1 - - def test_close_w_exceptions(self): - """Raise exceptions on close""" - from google.cloud.bigtable.data import exceptions - - expected_total = 10 - expected_exceptions = [RuntimeError("mock")] - with self._make_one() as instance: - instance._oldest_exceptions = expected_exceptions - instance._entries_processed_since_last_raise 
= expected_total - try: - instance.close() - except exceptions.MutationsExceptionGroup as exc: - assert list(exc.exceptions) == expected_exceptions - assert str(expected_total) in str(exc) - assert instance._entries_processed_since_last_raise == 0 - (instance._oldest_exceptions, instance._newest_exceptions) = ([], []) - - def test__on_exit(self, recwarn): - """Should raise warnings if unflushed mutations exist""" - with self._make_one() as instance: - instance._on_exit() - assert len(recwarn) == 0 - num_left = 4 - instance._staged_entries = [mock.Mock()] * num_left - with pytest.warns(UserWarning) as w: - instance._on_exit() - assert len(w) == 1 - assert "unflushed mutations" in str(w[0].message).lower() - assert str(num_left) in str(w[0].message) - instance._closed.set() - instance._on_exit() - assert len(recwarn) == 0 - instance._staged_entries = [] - - def test_atexit_registration(self): - """Should run _on_exit on program termination""" - import atexit - - with mock.patch.object(atexit, "register") as register_mock: - assert register_mock.call_count == 0 - with self._make_one(): - assert register_mock.call_count == 1 - - def test_timeout_args_passed(self): - """ - batch_operation_timeout and batch_attempt_timeout should be used - in api calls - """ - if self.is_async(): - mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" - else: - mutate_path = "_sync._mutate_rows._MutateRowsOperation" - with mock.patch( - f"google.cloud.bigtable.data.{mutate_path}", return_value=mock.Mock() - ) as mutate_rows: - expected_operation_timeout = 17 - expected_attempt_timeout = 13 - with self._make_one( - batch_operation_timeout=expected_operation_timeout, - batch_attempt_timeout=expected_attempt_timeout, - ) as instance: - assert instance._operation_timeout == expected_operation_timeout - assert instance._attempt_timeout == expected_attempt_timeout - instance._execute_mutate_rows([self._make_mutation()]) - assert mutate_rows.call_count == 1 - kwargs = mutate_rows.call_args[1] - assert kwargs["operation_timeout"] == expected_operation_timeout - assert kwargs["attempt_timeout"] == expected_attempt_timeout - - @pytest.mark.parametrize( - "limit,in_e,start_e,end_e", - [ - (10, 0, (10, 0), (10, 0)), - (1, 10, (0, 0), (1, 1)), - (10, 1, (0, 0), (1, 0)), - (10, 10, (0, 0), (10, 0)), - (10, 11, (0, 0), (10, 1)), - (3, 20, (0, 0), (3, 3)), - (10, 20, (0, 0), (10, 10)), - (10, 21, (0, 0), (10, 10)), - (2, 1, (2, 0), (2, 1)), - (2, 1, (1, 0), (2, 0)), - (2, 2, (1, 0), (2, 1)), - (3, 1, (3, 1), (3, 2)), - (3, 3, (3, 1), (3, 3)), - (1000, 5, (999, 0), (1000, 4)), - (1000, 5, (0, 0), (5, 0)), - (1000, 5, (1000, 0), (1000, 5)), - ], - ) - def test__add_exceptions(self, limit, in_e, start_e, end_e): - """ - Test that the _add_exceptions function properly updates the - _oldest_exceptions and _newest_exceptions lists - Args: - - limit: the _exception_list_limit representing the max size of either list - - in_e: size of list of exceptions to send to _add_exceptions - - start_e: a tuple of ints representing the initial sizes of _oldest_exceptions and _newest_exceptions - - end_e: a tuple of ints representing the expected sizes of _oldest_exceptions and _newest_exceptions - """ - from collections import deque - - input_list = [RuntimeError(f"mock {i}") for i in range(in_e)] - mock_batcher = mock.Mock() - mock_batcher._oldest_exceptions = [ - RuntimeError(f"starting mock {i}") for i in range(start_e[0]) - ] - mock_batcher._newest_exceptions = deque( - [RuntimeError(f"starting mock {i}") for i in 
range(start_e[1])], - maxlen=limit, - ) - mock_batcher._exception_list_limit = limit - mock_batcher._exceptions_since_last_raise = 0 - self._get_target_class()._add_exceptions(mock_batcher, input_list) - assert len(mock_batcher._oldest_exceptions) == end_e[0] - assert len(mock_batcher._newest_exceptions) == end_e[1] - assert mock_batcher._exceptions_since_last_raise == in_e - oldest_list_diff = end_e[0] - start_e[0] - newest_list_diff = min(max(in_e - oldest_list_diff, 0), limit) - for i in range(oldest_list_diff): - assert mock_batcher._oldest_exceptions[i + start_e[0]] == input_list[i] - for i in range(1, newest_list_diff + 1): - assert mock_batcher._newest_exceptions[-i] == input_list[-i] - - @pytest.mark.parametrize( - "input_retryables,expected_retryables", - [ - ( - TABLE_DEFAULT.READ_ROWS, - [ - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - core_exceptions.Aborted, - ], - ), - ( - TABLE_DEFAULT.DEFAULT, - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ), - ( - TABLE_DEFAULT.MUTATE_ROWS, - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ), - ([], []), - ([4], [core_exceptions.DeadlineExceeded]), - ], - ) - def test_customizable_retryable_errors(self, input_retryables, expected_retryables): - """ - Test that retryable functions support user-configurable arguments, and that the configured retryables are passed - down to the gapic layer. - """ - retryn_fn = ( - "retry_target_async" - if "Async" in self._get_target_class().__name__ - else "retry_target" - ) - with mock.patch.object( - google.api_core.retry, "if_exception_type" - ) as predicate_builder_mock: - with mock.patch.object(google.api_core.retry, retryn_fn) as retry_fn_mock: - table = None - with mock.patch("asyncio.create_task"): - table = Table(mock.Mock(), "instance", "table") - with self._make_one( - table, batch_retryable_errors=input_retryables - ) as instance: - assert instance._retryable_errors == expected_retryables - expected_predicate = lambda a: a in expected_retryables - predicate_builder_mock.return_value = expected_predicate - retry_fn_mock.side_effect = RuntimeError("stop early") - mutation = self._make_mutation(count=1, size=1) - instance._execute_mutate_rows([mutation]) - predicate_builder_mock.assert_called_once_with( - *expected_retryables, _MutateRowsIncomplete - ) - retry_call_args = retry_fn_mock.call_args_list[0].args - assert retry_call_args[1] is expected_predicate - - -class TestBigtableDataClient(ABC): - @staticmethod - def _get_target_class(): - from google.cloud.bigtable.data._sync.client import BigtableDataClient - - return BigtableDataClient - - @classmethod - def _make_client(cls, *args, use_emulator=True, **kwargs): - import os - - env_mask = {} - if use_emulator: - env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" - import warnings - - warnings.filterwarnings("ignore", category=RuntimeWarning) - else: - kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) - kwargs["project"] = kwargs.get("project", "project-id") - with mock.patch.dict(os.environ, env_mask): - return cls._get_target_class()(*args, **kwargs) - - @property - def is_async(self): - return False - - def test_ctor(self): - expected_project = "project-id" - expected_pool_size = 11 - expected_credentials = AnonymousCredentials() - client = self._make_client( - project="project-id", - pool_size=expected_pool_size, - credentials=expected_credentials, - use_emulator=False, - ) - time.sleep(0) - assert client.project == expected_project - 
assert len(client.transport._grpc_channel._pool) == expected_pool_size - assert not client._active_instances - assert len(client._channel_refresh_tasks) == expected_pool_size - assert client.transport._credentials == expected_credentials - client.close() - - def test_ctor_super_inits(self): - from google.cloud.client import ClientWithProject - from google.api_core import client_options as client_options_lib - from google.cloud.bigtable import __version__ as bigtable_version - - project = "project-id" - pool_size = 11 - credentials = AnonymousCredentials() - client_options = {"api_endpoint": "foo.bar:1234"} - options_parsed = client_options_lib.from_dict(client_options) - asyncio_portion = "-async" if self.is_async else "" - transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" - with mock.patch.object(BigtableClient, "__init__") as bigtable_client_init: - bigtable_client_init.return_value = None - with mock.patch.object( - ClientWithProject, "__init__" - ) as client_project_init: - client_project_init.return_value = None - try: - self._make_client( - project=project, - pool_size=pool_size, - credentials=credentials, - client_options=options_parsed, - use_emulator=False, - ) - except AttributeError: - pass - assert bigtable_client_init.call_count == 1 - kwargs = bigtable_client_init.call_args[1] - assert kwargs["transport"] == transport_str - assert kwargs["credentials"] == credentials - assert kwargs["client_options"] == options_parsed - assert client_project_init.call_count == 1 - kwargs = client_project_init.call_args[1] - assert kwargs["project"] == project - assert kwargs["credentials"] == credentials - assert kwargs["client_options"] == options_parsed - - def test_ctor_dict_options(self): - from google.api_core.client_options import ClientOptions - - client_options = {"api_endpoint": "foo.bar:1234"} - with mock.patch.object(BigtableClient, "__init__") as bigtable_client_init: - try: - self._make_client(client_options=client_options) - except TypeError: - pass - bigtable_client_init.assert_called_once() - kwargs = bigtable_client_init.call_args[1] - called_options = kwargs["client_options"] - assert called_options.api_endpoint == "foo.bar:1234" - assert isinstance(called_options, ClientOptions) - with mock.patch.object( - self._get_target_class(), "_start_background_channel_refresh" - ) as start_background_refresh: - client = self._make_client( - client_options=client_options, use_emulator=False - ) - start_background_refresh.assert_called_once() - client.close() - - def test_veneer_grpc_headers(self): - client_component = "data-async" if self.is_async else "data" - VENEER_HEADER_REGEX = re.compile( - "gapic\\/[0-9]+\\.[\\w.-]+ gax\\/[0-9]+\\.[\\w.-]+ gccl\\/[0-9]+\\.[\\w.-]+-" - + client_component - + " gl-python\\/[0-9]+\\.[\\w.-]+ grpc\\/[0-9]+\\.[\\w.-]+" - ) - if self.is_async: - patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") - else: - patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") - with patch as gapic_mock: - client = self._make_client(project="project-id") - wrapped_call_list = gapic_mock.call_args_list - assert len(wrapped_call_list) > 0 - for call in wrapped_call_list: - client_info = call.kwargs["client_info"] - assert client_info is not None, f"{call} has no client_info" - wrapped_user_agent_sorted = " ".join( - sorted(client_info.to_user_agent().split(" ")) - ) - assert VENEER_HEADER_REGEX.match( - wrapped_user_agent_sorted - ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" - 
client.close() - - def test_channel_pool_creation(self): - pool_size = 14 - with mock.patch.object( - grpc_helpers, "create_channel", mock.Mock() - ) as create_channel: - client = self._make_client(project="project-id", pool_size=pool_size) - assert create_channel.call_count == pool_size - client.close() - client = self._make_client(project="project-id", pool_size=pool_size) - pool_list = list(client.transport._grpc_channel._pool) - pool_set = set(client.transport._grpc_channel._pool) - assert len(pool_list) == len(pool_set) - client.close() - - def test_channel_pool_rotation(self): - pool_size = 7 - with mock.patch.object(PooledChannel, "next_channel") as next_channel: - client = self._make_client(project="project-id", pool_size=pool_size) - assert len(client.transport._grpc_channel._pool) == pool_size - next_channel.reset_mock() - with mock.patch.object( - type(client.transport._grpc_channel._pool[0]), "unary_unary" - ) as unary_unary: - channel_next = None - for i in range(pool_size): - channel_last = channel_next - channel_next = client.transport.grpc_channel._pool[i] - assert channel_last != channel_next - next_channel.return_value = channel_next - client.transport.ping_and_warm() - assert next_channel.call_count == i + 1 - unary_unary.assert_called_once() - unary_unary.reset_mock() - client.close() - - def test_channel_pool_replace(self): - import time - - sleep_module = asyncio if self.is_async else time - with mock.patch.object(sleep_module, "sleep"): - pool_size = 7 - client = self._make_client(project="project-id", pool_size=pool_size) - for replace_idx in range(pool_size): - start_pool = [ - channel for channel in client.transport._grpc_channel._pool - ] - grace_period = 9 - with mock.patch.object( - type(client.transport._grpc_channel._pool[-1]), "close" - ) as close: - new_channel = client.transport.create_channel() - client.transport.replace_channel( - replace_idx, grace=grace_period, new_channel=new_channel - ) - close.assert_called_once() - if self.is_async: - close.assert_called_once_with(grace=grace_period) - close.assert_called_once() - assert client.transport._grpc_channel._pool[replace_idx] == new_channel - for i in range(pool_size): - if i != replace_idx: - assert client.transport._grpc_channel._pool[i] == start_pool[i] - else: - assert client.transport._grpc_channel._pool[i] != start_pool[i] - client.close() - - def test__start_background_channel_refresh_tasks_exist(self): - client = self._make_client(project="project-id", use_emulator=False) - assert len(client._channel_refresh_tasks) > 0 - with mock.patch.object(asyncio, "create_task") as create_task: - client._start_background_channel_refresh() - create_task.assert_not_called() - client.close() - - @pytest.mark.parametrize("pool_size", [1, 3, 7]) - def test__start_background_channel_refresh(self, pool_size): - import concurrent.futures - - with mock.patch.object( - self._get_target_class(), "_ping_and_warm_instances", mock.Mock() - ) as ping_and_warm: - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - client._start_background_channel_refresh() - assert len(client._channel_refresh_tasks) == pool_size - for task in client._channel_refresh_tasks: - if self.is_async: - assert isinstance(task, asyncio.Task) - else: - assert isinstance(task, concurrent.futures.Future) - time.sleep(0.1) - assert ping_and_warm.call_count == pool_size - for channel in client.transport._grpc_channel._pool: - ping_and_warm.assert_any_call(channel) - client.close() - - def 
test__ping_and_warm_instances(self): - """test ping and warm with mocked asyncio.gather""" - client_mock = mock.Mock() - client_mock._execute_ping_and_warms = ( - lambda *args: self._get_target_class()._execute_ping_and_warms( - client_mock, *args - ) - ) - gather_tuple = ( - (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") - ) - with mock.patch.object(*gather_tuple, mock.Mock()) as gather: - if self.is_async: - gather.side_effect = lambda *args, **kwargs: [None for _ in args] - else: - gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] - channel = mock.Mock() - client_mock._active_instances = [] - result = self._get_target_class()._ping_and_warm_instances( - client_mock, channel - ) - assert len(result) == 0 - if self.is_async: - assert gather.call_args.kwargs == {"return_exceptions": True} - client_mock._active_instances = [ - (mock.Mock(), mock.Mock(), mock.Mock()) - ] * 4 - gather.reset_mock() - channel.reset_mock() - result = self._get_target_class()._ping_and_warm_instances( - client_mock, channel - ) - assert len(result) == 4 - if self.is_async: - gather.assert_called_once() - gather.assert_called_once() - assert len(gather.call_args.args) == 4 - else: - assert gather.call_count == 4 - grpc_call_args = channel.unary_unary().call_args_list - for idx, (_, kwargs) in enumerate(grpc_call_args): - ( - expected_instance, - expected_table, - expected_app_profile, - ) = client_mock._active_instances[idx] - request = kwargs["request"] - assert request["name"] == expected_instance - assert request["app_profile_id"] == expected_app_profile - metadata = kwargs["metadata"] - assert len(metadata) == 1 - assert metadata[0][0] == "x-goog-request-params" - assert ( - metadata[0][1] - == f"name={expected_instance}&app_profile_id={expected_app_profile}" - ) - - def test__ping_and_warm_single_instance(self): - """should be able to call ping and warm with single instance""" - client_mock = mock.Mock() - client_mock._execute_ping_and_warms = ( - lambda *args: self._get_target_class()._execute_ping_and_warms( - client_mock, *args - ) - ) - gather_tuple = ( - (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") - ) - with mock.patch.object(*gather_tuple, mock.Mock()) as gather: - gather.side_effect = lambda *args, **kwargs: [mock.Mock() for _ in args] - if self.is_async: - gather.side_effect = lambda *args, **kwargs: [None for _ in args] - else: - gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] - channel = mock.Mock() - client_mock._active_instances = [mock.Mock()] * 100 - test_key = ("test-instance", "test-table", "test-app-profile") - result = self._get_target_class()._ping_and_warm_instances( - client_mock, channel, test_key - ) - assert len(result) == 1 - grpc_call_args = channel.unary_unary().call_args_list - assert len(grpc_call_args) == 1 - kwargs = grpc_call_args[0][1] - request = kwargs["request"] - assert request["name"] == "test-instance" - assert request["app_profile_id"] == "test-app-profile" - metadata = kwargs["metadata"] - assert len(metadata) == 1 - assert metadata[0][0] == "x-goog-request-params" - assert ( - metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" - ) - - @pytest.mark.parametrize( - "refresh_interval, wait_time, expected_sleep", - [(0, 0, 0), (0, 1, 0), (10, 0, 10), (10, 5, 5), (10, 10, 0), (10, 15, 0)], - ) - def test__manage_channel_first_sleep( - self, refresh_interval, wait_time, expected_sleep - ): - import threading - import time - - with mock.patch.object(time, "monotonic") as 
monotonic: - monotonic.return_value = 0 - sleep_tuple = ( - (asyncio, "sleep") if self.is_async else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: - sleep.side_effect = asyncio.CancelledError - try: - client = self._make_client(project="project-id") - client._channel_init_time = -wait_time - client._manage_channel(0, refresh_interval, refresh_interval) - except asyncio.CancelledError: - pass - sleep.assert_called_once() - call_time = sleep.call_args[0][0] - assert ( - abs(call_time - expected_sleep) < 0.1 - ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" - client.close() - - def test__manage_channel_ping_and_warm(self): - """_manage channel should call ping and warm internally""" - import time - import threading - - client_mock = mock.Mock() - client_mock._is_closed.is_set.return_value = False - client_mock._channel_init_time = time.monotonic() - channel_list = [mock.Mock(), mock.Mock()] - client_mock.transport.channels = channel_list - new_channel = mock.Mock() - client_mock.transport.grpc_channel._create_channel.return_value = new_channel - sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") - with mock.patch.object(*sleep_tuple): - client_mock.transport.replace_channel.side_effect = asyncio.CancelledError - ping_and_warm = client_mock._ping_and_warm_instances = mock.Mock() - try: - channel_idx = 1 - self._get_target_class()._manage_channel(client_mock, channel_idx, 10) - except asyncio.CancelledError: - pass - assert ping_and_warm.call_count == 2 - assert client_mock.transport.replace_channel.call_count == 1 - old_channel = channel_list[channel_idx] - assert old_channel != new_channel - called_with = [call[0][0] for call in ping_and_warm.call_args_list] - assert old_channel in called_with - assert new_channel in called_with - ping_and_warm.reset_mock() - try: - self._get_target_class()._manage_channel(client_mock, 0, 0, 0) - except asyncio.CancelledError: - pass - ping_and_warm.assert_called_once_with(new_channel) - - @pytest.mark.parametrize( - "refresh_interval, num_cycles, expected_sleep", - [(None, 1, 60 * 35), (10, 10, 100), (10, 1, 10)], - ) - def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sleep): - import time - import random - import threading - - channel_idx = 1 - with mock.patch.object(random, "uniform") as uniform: - uniform.side_effect = lambda min_, max_: min_ - with mock.patch.object(time, "time") as time_mock: - time_mock.return_value = 0 - sleep_tuple = ( - (asyncio, "sleep") if self.is_async else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: - sleep.side_effect = [None for i in range(num_cycles - 1)] + [ - asyncio.CancelledError - ] - client = self._make_client(project="project-id") - with mock.patch.object(client.transport, "replace_channel"): - try: - if refresh_interval is not None: - client._manage_channel( - channel_idx, refresh_interval, refresh_interval - ) - else: - client._manage_channel(channel_idx) - except asyncio.CancelledError: - pass - assert sleep.call_count == num_cycles - total_sleep = sum([call[0][0] for call in sleep.call_args_list]) - assert ( - abs(total_sleep - expected_sleep) < 0.1 - ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" - client.close() - - def test__manage_channel_random(self): - import random - import threading - - sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") - with 
mock.patch.object(*sleep_tuple) as sleep: - with mock.patch.object(random, "uniform") as uniform: - uniform.return_value = 0 - try: - uniform.side_effect = asyncio.CancelledError - client = self._make_client(project="project-id", pool_size=1) - except asyncio.CancelledError: - uniform.side_effect = None - uniform.reset_mock() - sleep.reset_mock() - min_val = 200 - max_val = 205 - uniform.side_effect = lambda min_, max_: min_ - sleep.side_effect = [None, None, asyncio.CancelledError] - try: - with mock.patch.object(client.transport, "replace_channel"): - client._manage_channel(0, min_val, max_val) - except asyncio.CancelledError: - pass - assert uniform.call_count == 3 - uniform_args = [call[0] for call in uniform.call_args_list] - for found_min, found_max in uniform_args: - assert found_min == min_val - assert found_max == max_val - - @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) - def test__manage_channel_refresh(self, num_cycles): - import threading - - expected_grace = 9 - expected_refresh = 0.5 - channel_idx = 1 - grpc_lib = grpc.aio if self.is_async else grpc - new_channel = grpc_lib.insecure_channel("localhost:8080") - with mock.patch.object( - PooledBigtableGrpcTransport, "replace_channel" - ) as replace_channel: - sleep_tuple = ( - (asyncio, "sleep") if self.is_async else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: - sleep.side_effect = [None for i in range(num_cycles)] + [ - asyncio.CancelledError - ] - with mock.patch.object( - grpc_helpers, "create_channel" - ) as create_channel: - create_channel.return_value = new_channel - with mock.patch.object( - self._get_target_class(), "_start_background_channel_refresh" - ): - client = self._make_client( - project="project-id", use_emulator=False - ) - create_channel.reset_mock() - try: - client._manage_channel( - channel_idx, - refresh_interval_min=expected_refresh, - refresh_interval_max=expected_refresh, - grace_period=expected_grace, - ) - except asyncio.CancelledError: - pass - assert sleep.call_count == num_cycles + 1 - assert create_channel.call_count == num_cycles - assert replace_channel.call_count == num_cycles - for call in replace_channel.call_args_list: - (args, kwargs) = call - assert args[0] == channel_idx - assert kwargs["grace"] == expected_grace - assert kwargs["new_channel"] == new_channel - client.close() - - def test__register_instance(self): - """test instance registration""" - client_mock = mock.Mock() - client_mock._gapic_client.instance_path.side_effect = lambda a, b: f"prefix/{b}" - active_instances = set() - instance_owners = {} - client_mock._active_instances = active_instances - client_mock._instance_owners = instance_owners - client_mock._channel_refresh_tasks = [] - client_mock._start_background_channel_refresh.side_effect = ( - lambda: client_mock._channel_refresh_tasks.append(mock.Mock) - ) - mock_channels = [mock.Mock() for i in range(5)] - client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = mock.Mock() - table_mock = mock.Mock() - self._get_target_class()._register_instance( - client_mock, "instance-1", table_mock - ) - assert client_mock._start_background_channel_refresh.call_count == 1 - expected_key = ( - "prefix/instance-1", - table_mock.table_name, - table_mock.app_profile_id, - ) - assert len(active_instances) == 1 - assert expected_key == tuple(list(active_instances)[0]) - assert len(instance_owners) == 1 - assert expected_key == tuple(list(instance_owners)[0]) - assert client_mock._channel_refresh_tasks - 
table_mock2 = mock.Mock() - self._get_target_class()._register_instance( - client_mock, "instance-2", table_mock2 - ) - assert client_mock._start_background_channel_refresh.call_count == 1 - assert client_mock._ping_and_warm_instances.call_count == len(mock_channels) - for channel in mock_channels: - assert channel in [ - call[0][0] - for call in client_mock._ping_and_warm_instances.call_args_list - ] - assert len(active_instances) == 2 - assert len(instance_owners) == 2 - expected_key2 = ( - "prefix/instance-2", - table_mock2.table_name, - table_mock2.app_profile_id, - ) - assert any( - [ - expected_key2 == tuple(list(active_instances)[i]) - for i in range(len(active_instances)) - ] - ) - assert any( - [ - expected_key2 == tuple(list(instance_owners)[i]) - for i in range(len(instance_owners)) - ] - ) - - @pytest.mark.parametrize( - "insert_instances,expected_active,expected_owner_keys", - [ - ([("i", "t", None)], [("i", "t", None)], [("i", "t", None)]), - ([("i", "t", "p")], [("i", "t", "p")], [("i", "t", "p")]), - ([("1", "t", "p"), ("1", "t", "p")], [("1", "t", "p")], [("1", "t", "p")]), - ( - [("1", "t", "p"), ("2", "t", "p")], - [("1", "t", "p"), ("2", "t", "p")], - [("1", "t", "p"), ("2", "t", "p")], - ), - ], - ) - def test__register_instance_state( - self, insert_instances, expected_active, expected_owner_keys - ): - """test that active_instances and instance_owners are updated as expected""" - client_mock = mock.Mock() - client_mock._gapic_client.instance_path.side_effect = lambda a, b: b - active_instances = set() - instance_owners = {} - client_mock._active_instances = active_instances - client_mock._instance_owners = instance_owners - client_mock._channel_refresh_tasks = [] - client_mock._start_background_channel_refresh.side_effect = ( - lambda: client_mock._channel_refresh_tasks.append(mock.Mock) - ) - mock_channels = [mock.Mock() for i in range(5)] - client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = mock.Mock() - table_mock = mock.Mock() - for instance, table, profile in insert_instances: - table_mock.table_name = table - table_mock.app_profile_id = profile - self._get_target_class()._register_instance( - client_mock, instance, table_mock - ) - assert len(active_instances) == len(expected_active) - assert len(instance_owners) == len(expected_owner_keys) - for expected in expected_active: - assert any( - [ - expected == tuple(list(active_instances)[i]) - for i in range(len(active_instances)) - ] - ) - for expected in expected_owner_keys: - assert any( - [ - expected == tuple(list(instance_owners)[i]) - for i in range(len(instance_owners)) - ] - ) - - def test__remove_instance_registration(self): - client = self._make_client(project="project-id") - table = mock.Mock() - client._register_instance("instance-1", table) - client._register_instance("instance-2", table) - assert len(client._active_instances) == 2 - assert len(client._instance_owners.keys()) == 2 - instance_1_path = client._gapic_client.instance_path( - client.project, "instance-1" - ) - instance_1_key = (instance_1_path, table.table_name, table.app_profile_id) - instance_2_path = client._gapic_client.instance_path( - client.project, "instance-2" - ) - instance_2_key = (instance_2_path, table.table_name, table.app_profile_id) - assert len(client._instance_owners[instance_1_key]) == 1 - assert list(client._instance_owners[instance_1_key])[0] == id(table) - assert len(client._instance_owners[instance_2_key]) == 1 - assert list(client._instance_owners[instance_2_key])[0] == 
id(table) - success = client._remove_instance_registration("instance-1", table) - assert success - assert len(client._active_instances) == 1 - assert len(client._instance_owners[instance_1_key]) == 0 - assert len(client._instance_owners[instance_2_key]) == 1 - assert client._active_instances == {instance_2_key} - success = client._remove_instance_registration("fake-key", table) - assert not success - assert len(client._active_instances) == 1 - client.close() - - def test__multiple_table_registration(self): - """ - registering with multiple tables with the same key should - add multiple owners to instance_owners, but only keep one copy - of shared key in active_instances - """ - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - - with self._make_client(project="project-id") as client: - with client.get_table("instance_1", "table_1") as table_1: - instance_1_path = client._gapic_client.instance_path( - client.project, "instance_1" - ) - instance_1_key = _WarmedInstanceKey( - instance_1_path, table_1.table_name, table_1.app_profile_id - ) - assert len(client._instance_owners[instance_1_key]) == 1 - assert len(client._active_instances) == 1 - assert id(table_1) in client._instance_owners[instance_1_key] - with client.get_table("instance_1", "table_1") as table_2: - assert len(client._instance_owners[instance_1_key]) == 2 - assert len(client._active_instances) == 1 - assert id(table_1) in client._instance_owners[instance_1_key] - assert id(table_2) in client._instance_owners[instance_1_key] - with client.get_table("instance_1", "table_3") as table_3: - instance_3_path = client._gapic_client.instance_path( - client.project, "instance_1" - ) - instance_3_key = _WarmedInstanceKey( - instance_3_path, table_3.table_name, table_3.app_profile_id - ) - assert len(client._instance_owners[instance_1_key]) == 2 - assert len(client._instance_owners[instance_3_key]) == 1 - assert len(client._active_instances) == 2 - assert id(table_1) in client._instance_owners[instance_1_key] - assert id(table_2) in client._instance_owners[instance_1_key] - assert id(table_3) in client._instance_owners[instance_3_key] - assert len(client._active_instances) == 1 - assert instance_1_key in client._active_instances - assert id(table_2) not in client._instance_owners[instance_1_key] - assert len(client._active_instances) == 0 - assert instance_1_key not in client._active_instances - assert len(client._instance_owners[instance_1_key]) == 0 - - def test__multiple_instance_registration(self): - """ - registering with multiple instance keys should update the key - in instance_owners and active_instances - """ - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - - with self._make_client(project="project-id") as client: - with client.get_table("instance_1", "table_1") as table_1: - with client.get_table("instance_2", "table_2") as table_2: - instance_1_path = client._gapic_client.instance_path( - client.project, "instance_1" - ) - instance_1_key = _WarmedInstanceKey( - instance_1_path, table_1.table_name, table_1.app_profile_id - ) - instance_2_path = client._gapic_client.instance_path( - client.project, "instance_2" - ) - instance_2_key = _WarmedInstanceKey( - instance_2_path, table_2.table_name, table_2.app_profile_id - ) - assert len(client._instance_owners[instance_1_key]) == 1 - assert len(client._instance_owners[instance_2_key]) == 1 - assert len(client._active_instances) == 2 - assert id(table_1) in client._instance_owners[instance_1_key] - assert id(table_2) in 
client._instance_owners[instance_2_key] - assert len(client._active_instances) == 1 - assert instance_1_key in client._active_instances - assert len(client._instance_owners[instance_2_key]) == 0 - assert len(client._instance_owners[instance_1_key]) == 1 - assert id(table_1) in client._instance_owners[instance_1_key] - assert len(client._active_instances) == 0 - assert len(client._instance_owners[instance_1_key]) == 0 - assert len(client._instance_owners[instance_2_key]) == 0 - - def test_get_table(self): - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - - client = self._make_client(project="project-id") - assert not client._active_instances - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - table = client.get_table( - expected_instance_id, expected_table_id, expected_app_profile_id - ) - time.sleep(0) - assert isinstance(table, TestTable._get_target_class()) - assert table.table_id == expected_table_id - assert ( - table.table_name - == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" - ) - assert table.instance_id == expected_instance_id - assert ( - table.instance_name - == f"projects/{client.project}/instances/{expected_instance_id}" - ) - assert table.app_profile_id == expected_app_profile_id - assert table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id - ) - assert instance_key in client._active_instances - assert client._instance_owners[instance_key] == {id(table)} - client.close() - - def test_get_table_arg_passthrough(self): - """All arguments passed in get_table should be sent to constructor""" - with self._make_client(project="project-id") as client: - with mock.patch.object( - TestTable._get_target_class(), "__init__" - ) as mock_constructor: - mock_constructor.return_value = None - assert not client._active_instances - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - expected_args = (1, "test", {"test": 2}) - expected_kwargs = {"hello": "world", "test": 2} - client.get_table( - expected_instance_id, - expected_table_id, - expected_app_profile_id, - *expected_args, - **expected_kwargs, - ) - mock_constructor.assert_called_once_with( - client, - expected_instance_id, - expected_table_id, - expected_app_profile_id, - *expected_args, - **expected_kwargs, - ) - - def test_get_table_context_manager(self): - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - expected_project_id = "project-id" - with mock.patch.object(TestTable._get_target_class(), "close") as close_mock: - with self._make_client(project=expected_project_id) as client: - with client.get_table( - expected_instance_id, expected_table_id, expected_app_profile_id - ) as table: - time.sleep(0) - assert isinstance(table, TestTable._get_target_class()) - assert table.table_id == expected_table_id - assert ( - table.table_name - == f"projects/{expected_project_id}/instances/{expected_instance_id}/tables/{expected_table_id}" - ) - assert table.instance_id == expected_instance_id - assert ( - table.instance_name - == f"projects/{expected_project_id}/instances/{expected_instance_id}" - ) - assert table.app_profile_id == expected_app_profile_id - assert table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, 
table.table_name, table.app_profile_id - ) - assert instance_key in client._active_instances - assert client._instance_owners[instance_key] == {id(table)} - assert close_mock.call_count == 1 - - def test_multiple_pool_sizes(self): - pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] - for pool_size in pool_sizes: - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - assert len(client._channel_refresh_tasks) == pool_size - client_duplicate = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - assert len(client_duplicate._channel_refresh_tasks) == pool_size - assert str(pool_size) in str(client.transport) - client.close() - client_duplicate.close() - - def test_close(self): - pool_size = 7 - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - assert len(client._channel_refresh_tasks) == pool_size - tasks_list = list(client._channel_refresh_tasks) - for task in client._channel_refresh_tasks: - assert not task.done() - with mock.patch.object( - PooledBigtableGrpcTransport, "close", mock.Mock() - ) as close_mock: - client.close() - close_mock.assert_called_once() - close_mock.assert_called_once() - for task in tasks_list: - assert task.done() - assert client._channel_refresh_tasks == [] - - def test_context_manager(self): - close_mock = mock.Mock() - true_close = None - with self._make_client(project="project-id") as client: - true_close = client.close() - client.close = close_mock - for task in client._channel_refresh_tasks: - assert not task.done() - assert client.project == "project-id" - assert client._active_instances == set() - close_mock.assert_not_called() - close_mock.assert_called_once() - close_mock.assert_called_once() - true_close - - -class TestTable(ABC): - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - @staticmethod - def _get_target_class(): - from google.cloud.bigtable.data._sync.client import Table - - return Table - - @property - def is_async(self): - return False - - def test_table_ctor(self): - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - expected_operation_timeout = 123 - expected_attempt_timeout = 12 - expected_read_rows_operation_timeout = 1.5 - expected_read_rows_attempt_timeout = 0.5 - expected_mutate_rows_operation_timeout = 2.5 - expected_mutate_rows_attempt_timeout = 0.75 - client = self._make_client() - assert not client._active_instances - table = self._get_target_class()( - client, - expected_instance_id, - expected_table_id, - expected_app_profile_id, - default_operation_timeout=expected_operation_timeout, - default_attempt_timeout=expected_attempt_timeout, - default_read_rows_operation_timeout=expected_read_rows_operation_timeout, - default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout, - default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, - default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, - ) - time.sleep(0) - assert table.table_id == expected_table_id - assert table.instance_id == expected_instance_id - assert table.app_profile_id == expected_app_profile_id - assert table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id - ) - assert instance_key in client._active_instances - assert 
client._instance_owners[instance_key] == {id(table)} - assert table.default_operation_timeout == expected_operation_timeout - assert table.default_attempt_timeout == expected_attempt_timeout - assert ( - table.default_read_rows_operation_timeout - == expected_read_rows_operation_timeout - ) - assert ( - table.default_read_rows_attempt_timeout - == expected_read_rows_attempt_timeout - ) - assert ( - table.default_mutate_rows_operation_timeout - == expected_mutate_rows_operation_timeout - ) - assert ( - table.default_mutate_rows_attempt_timeout - == expected_mutate_rows_attempt_timeout - ) - table._register_instance_future - assert table._register_instance_future.done() - assert not table._register_instance_future.cancelled() - assert table._register_instance_future.exception() is None - client.close() - - def test_table_ctor_defaults(self): - """should provide default timeout values and app_profile_id""" - expected_table_id = "table-id" - expected_instance_id = "instance-id" - client = self._make_client() - assert not client._active_instances - table = Table(client, expected_instance_id, expected_table_id) - time.sleep(0) - assert table.table_id == expected_table_id - assert table.instance_id == expected_instance_id - assert table.app_profile_id is None - assert table.client is client - assert table.default_operation_timeout == 60 - assert table.default_read_rows_operation_timeout == 600 - assert table.default_mutate_rows_operation_timeout == 600 - assert table.default_attempt_timeout == 20 - assert table.default_read_rows_attempt_timeout == 20 - assert table.default_mutate_rows_attempt_timeout == 60 - client.close() - - def test_table_ctor_invalid_timeout_values(self): - """bad timeout values should raise ValueError""" - client = self._make_client() - timeout_pairs = [ - ("default_operation_timeout", "default_attempt_timeout"), - ( - "default_read_rows_operation_timeout", - "default_read_rows_attempt_timeout", - ), - ( - "default_mutate_rows_operation_timeout", - "default_mutate_rows_attempt_timeout", - ), - ] - for operation_timeout, attempt_timeout in timeout_pairs: - with pytest.raises(ValueError) as e: - Table(client, "", "", **{attempt_timeout: -1}) - assert "attempt_timeout must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - Table(client, "", "", **{operation_timeout: -1}) - assert "operation_timeout must be greater than 0" in str(e.value) - client.close() - - @pytest.mark.parametrize( - "fn_name,fn_args,is_stream,extra_retryables", - [ - ("read_rows_stream", (ReadRowsQuery(),), True, ()), - ("read_rows", (ReadRowsQuery(),), True, ()), - ("read_row", (b"row_key",), True, ()), - ("read_rows_sharded", ([ReadRowsQuery()],), True, ()), - ("row_exists", (b"row_key",), True, ()), - ("sample_row_keys", (), False, ()), - ("mutate_row", (b"row_key", [mock.Mock()]), False, ()), - ( - "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), - False, - (_MutateRowsIncomplete,), - ), - ], - ) - @pytest.mark.parametrize( - "input_retryables,expected_retryables", - [ - ( - TABLE_DEFAULT.READ_ROWS, - [ - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - core_exceptions.Aborted, - ], - ), - ( - TABLE_DEFAULT.DEFAULT, - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ), - ( - TABLE_DEFAULT.MUTATE_ROWS, - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ), - ([], []), - ([4], [core_exceptions.DeadlineExceeded]), - ], - ) - def test_customizable_retryable_errors( 
- self, - input_retryables, - expected_retryables, - fn_name, - fn_args, - is_stream, - extra_retryables, - ): - """ - Test that retryable functions support user-configurable arguments, and that the configured retryables are passed - down to the gapic layer. - """ - retry_fn = "retry_target" - if is_stream: - retry_fn += "_stream" - if self.is_async: - retry_fn += "_async" - with mock.patch(f"google.api_core.retry.{retry_fn}") as retry_fn_mock: - with self._make_client() as client: - table = client.get_table("instance-id", "table-id") - expected_predicate = lambda a: a in expected_retryables - retry_fn_mock.side_effect = RuntimeError("stop early") - with mock.patch( - "google.api_core.retry.if_exception_type" - ) as predicate_builder_mock: - predicate_builder_mock.return_value = expected_predicate - with pytest.raises(Exception): - test_fn = table.__getattribute__(fn_name) - test_fn(*fn_args, retryable_errors=input_retryables) - predicate_builder_mock.assert_called_once_with( - *expected_retryables, *extra_retryables - ) - retry_call_args = retry_fn_mock.call_args_list[0].args - assert retry_call_args[1] is expected_predicate - - @pytest.mark.parametrize( - "fn_name,fn_args,gapic_fn", - [ - ("read_rows_stream", (ReadRowsQuery(),), "read_rows"), - ("read_rows", (ReadRowsQuery(),), "read_rows"), - ("read_row", (b"row_key",), "read_rows"), - ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), - ("row_exists", (b"row_key",), "read_rows"), - ("sample_row_keys", (), "sample_row_keys"), - ("mutate_row", (b"row_key", [mock.Mock()]), "mutate_row"), - ( - "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), - "mutate_rows", - ), - ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), - ( - "read_modify_write_row", - (b"row_key", mock.Mock()), - "read_modify_write_row", - ), - ], - ) - @pytest.mark.parametrize("include_app_profile", [True, False]) - def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): - """check that all requests attach proper metadata headers""" - profile = "profile" if include_app_profile else None - with mock.patch.object( - BigtableClient, gapic_fn, mock.mock.Mock() - ) as gapic_mock: - gapic_mock.side_effect = RuntimeError("stop early") - with self._make_client() as client: - table = Table(client, "instance-id", "table-id", profile) - try: - test_fn = table.__getattribute__(fn_name) - maybe_stream = test_fn(*fn_args) - [i for i in maybe_stream] - except Exception: - pass - kwargs = gapic_mock.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata - - -class TestReadRows(ABC): - """ - Tests for table.read_rows and related methods. 
- """ - - @staticmethod - def _get_operation_class(): - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - - return _ReadRowsOperation - - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - def _make_table(self, *args, **kwargs): - client_mock = mock.Mock() - client_mock._register_instance.side_effect = lambda *args, **kwargs: time.sleep( - 0 - ) - client_mock._remove_instance_registration.side_effect = ( - lambda *args, **kwargs: time.sleep(0) - ) - kwargs["instance_id"] = kwargs.get( - "instance_id", args[0] if args else "instance" - ) - kwargs["table_id"] = kwargs.get( - "table_id", args[1] if len(args) > 1 else "table" - ) - client_mock._gapic_client.table_path.return_value = kwargs["table_id"] - client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] - return Table(client_mock, *args, **kwargs) - - def _make_stats(self): - from google.cloud.bigtable_v2.types import RequestStats - from google.cloud.bigtable_v2.types import FullReadStatsView - from google.cloud.bigtable_v2.types import ReadIterationStats - - return RequestStats( - full_read_stats_view=FullReadStatsView( - read_iteration_stats=ReadIterationStats( - rows_seen_count=1, - rows_returned_count=2, - cells_seen_count=3, - cells_returned_count=4, - ) - ) - ) - - @staticmethod - def _make_chunk(*args, **kwargs): - from google.cloud.bigtable_v2 import ReadRowsResponse - - kwargs["row_key"] = kwargs.get("row_key", b"row_key") - kwargs["family_name"] = kwargs.get("family_name", "family_name") - kwargs["qualifier"] = kwargs.get("qualifier", b"qualifier") - kwargs["value"] = kwargs.get("value", b"value") - kwargs["commit_row"] = kwargs.get("commit_row", True) - return ReadRowsResponse.CellChunk(*args, **kwargs) - - @staticmethod - def _make_gapic_stream( - chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0 - ): - from google.cloud.bigtable_v2 import ReadRowsResponse - - class mock_stream: - def __init__(self, chunk_list, sleep_time): - self.chunk_list = chunk_list - self.idx = -1 - self.sleep_time = sleep_time - - def __iter__(self): - return self - - def __next__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - if sleep_time: - time.sleep(self.sleep_time) - chunk = self.chunk_list[self.idx] - if isinstance(chunk, Exception): - raise chunk - else: - return ReadRowsResponse(chunks=[chunk]) - raise StopIteration - - def cancel(self): - pass - - return mock_stream(chunk_list, sleep_time) - - def execute_fn(self, table, *args, **kwargs): - return table.read_rows(*args, **kwargs) - - def test_read_rows(self): - query = ReadRowsQuery() - chunks = [ - self._make_chunk(row_key=b"test_1"), - self._make_chunk(row_key=b"test_2"), - ] - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks - ) - results = self.execute_fn(table, query, operation_timeout=3) - assert len(results) == 2 - assert results[0].row_key == b"test_1" - assert results[1].row_key == b"test_2" - - def test_read_rows_stream(self): - query = ReadRowsQuery() - chunks = [ - self._make_chunk(row_key=b"test_1"), - self._make_chunk(row_key=b"test_2"), - ] - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks - ) - gen = table.read_rows_stream(query, operation_timeout=3) - results = [row for row in gen] - assert 
len(results) == 2 - assert results[0].row_key == b"test_1" - assert results[1].row_key == b"test_2" - - @pytest.mark.parametrize("include_app_profile", [True, False]) - def test_read_rows_query_matches_request(self, include_app_profile): - from google.cloud.bigtable.data import RowRange - from google.cloud.bigtable.data.row_filters import PassAllFilter - - app_profile_id = "app_profile_id" if include_app_profile else None - with self._make_table(app_profile_id=app_profile_id) as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream([]) - row_keys = [b"test_1", "test_2"] - row_ranges = RowRange("1start", "2end") - filter_ = PassAllFilter(True) - limit = 99 - query = ReadRowsQuery( - row_keys=row_keys, - row_ranges=row_ranges, - row_filter=filter_, - limit=limit, - ) - results = table.read_rows(query, operation_timeout=3) - assert len(results) == 0 - call_request = read_rows.call_args_list[0][0][0] - query_pb = query._to_pb(table) - assert call_request == query_pb - - @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) - def test_read_rows_timeout(self, operation_timeout): - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - query = ReadRowsQuery() - chunks = [self._make_chunk(row_key=b"test_1")] - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks, sleep_time=1 - ) - try: - table.read_rows(query, operation_timeout=operation_timeout) - except core_exceptions.DeadlineExceeded as e: - assert ( - e.message - == f"operation_timeout of {operation_timeout:0.1f}s exceeded" - ) - - @pytest.mark.parametrize( - "per_request_t, operation_t, expected_num", - [(0.05, 0.08, 2), (0.05, 0.54, 11), (0.05, 0.14, 3), (0.05, 0.24, 5)], - ) - def test_read_rows_attempt_timeout(self, per_request_t, operation_t, expected_num): - """ - Ensures that the attempt_timeout is respected and that the number of - requests is as expected. - - operation_timeout does not cancel the request, so we expect the number of - requests to be the ceiling of operation_timeout / attempt_timeout. 
- """ - from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - - expected_last_timeout = operation_t - (expected_num - 1) * per_request_t - with mock.patch("random.uniform", side_effect=lambda a, b: 0): - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks, sleep_time=per_request_t - ) - query = ReadRowsQuery() - chunks = [core_exceptions.DeadlineExceeded("mock deadline")] - try: - table.read_rows( - query, - operation_timeout=operation_t, - attempt_timeout=per_request_t, - ) - except core_exceptions.DeadlineExceeded as e: - retry_exc = e.__cause__ - if expected_num == 0: - assert retry_exc is None - else: - assert type(retry_exc) is RetryExceptionGroup - assert f"{expected_num} failed attempts" in str(retry_exc) - assert len(retry_exc.exceptions) == expected_num - for sub_exc in retry_exc.exceptions: - assert sub_exc.message == "mock deadline" - assert read_rows.call_count == expected_num - for _, call_kwargs in read_rows.call_args_list[:-1]: - assert call_kwargs["timeout"] == per_request_t - assert call_kwargs["retry"] is None - assert ( - abs( - read_rows.call_args_list[-1][1]["timeout"] - - expected_last_timeout - ) - < 0.05 - ) - - @pytest.mark.parametrize( - "exc_type", - [ - core_exceptions.Aborted, - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - ], - ) - def test_read_rows_retryable_error(self, exc_type): - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - [expected_error] - ) - query = ReadRowsQuery() - expected_error = exc_type("mock error") - try: - table.read_rows(query, operation_timeout=0.1) - except core_exceptions.DeadlineExceeded as e: - retry_exc = e.__cause__ - root_cause = retry_exc.exceptions[0] - assert type(root_cause) is exc_type - assert root_cause == expected_error - - @pytest.mark.parametrize( - "exc_type", - [ - core_exceptions.Cancelled, - core_exceptions.PreconditionFailed, - core_exceptions.NotFound, - core_exceptions.PermissionDenied, - core_exceptions.Conflict, - core_exceptions.InternalServerError, - core_exceptions.TooManyRequests, - core_exceptions.ResourceExhausted, - InvalidChunk, - ], - ) - def test_read_rows_non_retryable_error(self, exc_type): - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - [expected_error] - ) - query = ReadRowsQuery() - expected_error = exc_type("mock error") - try: - table.read_rows(query, operation_timeout=0.1) - except exc_type as e: - assert e == expected_error - - def test_read_rows_revise_request(self): - """Ensure that _revise_request is called between retries""" - from google.cloud.bigtable.data.exceptions import InvalidChunk - from google.cloud.bigtable_v2.types import RowSet - - return_val = RowSet() - with mock.patch.object( - self._get_operation_class(), "_revise_request_rowset" - ) as revise_rowset: - revise_rowset.return_value = return_val - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks - ) - row_keys = [b"test_1", b"test_2", b"test_3"] - query = ReadRowsQuery(row_keys=row_keys) - chunks = [ - self._make_chunk(row_key=b"test_1"), - core_exceptions.Aborted("mock retryable error"), - ] - try: - 
table.read_rows(query) - except InvalidChunk: - revise_rowset.assert_called() - first_call_kwargs = revise_rowset.call_args_list[0].kwargs - assert first_call_kwargs["row_set"] == query._to_pb(table).rows - assert first_call_kwargs["last_seen_row_key"] == b"test_1" - revised_call = read_rows.call_args_list[1].args[0] - assert revised_call.rows == return_val - - def test_read_rows_default_timeouts(self): - """Ensure that the default timeouts are set on the read rows operation when not overridden""" - operation_timeout = 8 - attempt_timeout = 4 - with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: - mock_op.side_effect = RuntimeError("mock error") - with self._make_table( - default_read_rows_operation_timeout=operation_timeout, - default_read_rows_attempt_timeout=attempt_timeout, - ) as table: - try: - table.read_rows(ReadRowsQuery()) - except RuntimeError: - pass - kwargs = mock_op.call_args_list[0].kwargs - assert kwargs["operation_timeout"] == operation_timeout - assert kwargs["attempt_timeout"] == attempt_timeout - - def test_read_rows_default_timeout_override(self): - """When timeouts are passed, they overwrite default values""" - operation_timeout = 8 - attempt_timeout = 4 - with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: - mock_op.side_effect = RuntimeError("mock error") - with self._make_table( - default_operation_timeout=99, default_attempt_timeout=97 - ) as table: - try: - table.read_rows( - ReadRowsQuery(), - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - ) - except RuntimeError: - pass - kwargs = mock_op.call_args_list[0].kwargs - assert kwargs["operation_timeout"] == operation_timeout - assert kwargs["attempt_timeout"] == attempt_timeout - - def test_read_row(self): - """Test reading a single row""" - with self._make_client() as client: - table = client.get_table("instance", "table") - row_key = b"test_1" - with mock.patch.object(table, "read_rows") as read_rows: - expected_result = object() - read_rows.side_effect = lambda *args, **kwargs: [expected_result] - expected_op_timeout = 8 - expected_req_timeout = 4 - row = table.read_row( - row_key, - operation_timeout=expected_op_timeout, - attempt_timeout=expected_req_timeout, - ) - assert row == expected_result - assert read_rows.call_count == 1 - (args, kwargs) = read_rows.call_args_list[0] - assert kwargs["operation_timeout"] == expected_op_timeout - assert kwargs["attempt_timeout"] == expected_req_timeout - assert len(args) == 1 - assert isinstance(args[0], ReadRowsQuery) - query = args[0] - assert query.row_keys == [row_key] - assert query.row_ranges == [] - assert query.limit == 1 - - def test_read_row_w_filter(self): - """Test reading a single row with an added filter""" - with self._make_client() as client: - table = client.get_table("instance", "table") - row_key = b"test_1" - with mock.patch.object(table, "read_rows") as read_rows: - expected_result = object() - read_rows.side_effect = lambda *args, **kwargs: [expected_result] - expected_op_timeout = 8 - expected_req_timeout = 4 - mock_filter = mock.Mock() - expected_filter = {"filter": "mock filter"} - mock_filter._to_dict.return_value = expected_filter - row = table.read_row( - row_key, - operation_timeout=expected_op_timeout, - attempt_timeout=expected_req_timeout, - row_filter=expected_filter, - ) - assert row == expected_result - assert read_rows.call_count == 1 - (args, kwargs) = read_rows.call_args_list[0] - assert kwargs["operation_timeout"] == expected_op_timeout - assert 
kwargs["attempt_timeout"] == expected_req_timeout - assert len(args) == 1 - assert isinstance(args[0], ReadRowsQuery) - query = args[0] - assert query.row_keys == [row_key] - assert query.row_ranges == [] - assert query.limit == 1 - assert query.filter == expected_filter - - def test_read_row_no_response(self): - """should return None if row does not exist""" - with self._make_client() as client: - table = client.get_table("instance", "table") - row_key = b"test_1" - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = lambda *args, **kwargs: [] - expected_op_timeout = 8 - expected_req_timeout = 4 - result = table.read_row( - row_key, - operation_timeout=expected_op_timeout, - attempt_timeout=expected_req_timeout, - ) - assert result is None - assert read_rows.call_count == 1 - (args, kwargs) = read_rows.call_args_list[0] - assert kwargs["operation_timeout"] == expected_op_timeout - assert kwargs["attempt_timeout"] == expected_req_timeout - assert isinstance(args[0], ReadRowsQuery) - query = args[0] - assert query.row_keys == [row_key] - assert query.row_ranges == [] - assert query.limit == 1 - - @pytest.mark.parametrize( - "return_value,expected_result", - [([], False), ([object()], True), ([object(), object()], True)], - ) - def test_row_exists(self, return_value, expected_result): - """Test checking for row existence""" - with self._make_client() as client: - table = client.get_table("instance", "table") - row_key = b"test_1" - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = lambda *args, **kwargs: return_value - expected_op_timeout = 1 - expected_req_timeout = 2 - result = table.row_exists( - row_key, - operation_timeout=expected_op_timeout, - attempt_timeout=expected_req_timeout, - ) - assert expected_result == result - assert read_rows.call_count == 1 - (args, kwargs) = read_rows.call_args_list[0] - assert kwargs["operation_timeout"] == expected_op_timeout - assert kwargs["attempt_timeout"] == expected_req_timeout - assert isinstance(args[0], ReadRowsQuery) - expected_filter = { - "chain": { - "filters": [ - {"cells_per_row_limit_filter": 1}, - {"strip_value_transformer": True}, - ] - } - } - query = args[0] - assert query.row_keys == [row_key] - assert query.row_ranges == [] - assert query.limit == 1 - assert query.filter._to_dict() == expected_filter - - -class TestReadRowsSharded(ABC): - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - def test_read_rows_sharded_empty_query(self): - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as exc: - table.read_rows_sharded([]) - assert "empty sharded_query" in str(exc.value) - - def test_read_rows_sharded_multiple_queries(self): - """Test with multiple queries. 
Should return results from both""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - table.client._gapic_client, "read_rows" - ) as read_rows: - read_rows.side_effect = ( - lambda *args, **kwargs: TestReadRows._make_gapic_stream( - [ - TestReadRows._make_chunk(row_key=k) - for k in args[0].rows.row_keys - ] - ) - ) - query_1 = ReadRowsQuery(b"test_1") - query_2 = ReadRowsQuery(b"test_2") - result = table.read_rows_sharded([query_1, query_2]) - assert len(result) == 2 - assert result[0].row_key == b"test_1" - assert result[1].row_key == b"test_2" - - @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24]) - def test_read_rows_sharded_multiple_queries_calls(self, n_queries): - """Each query should trigger a separate read_rows call""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - query_list = [ReadRowsQuery() for _ in range(n_queries)] - table.read_rows_sharded(query_list) - assert read_rows.call_count == n_queries - - def test_read_rows_sharded_errors(self): - """Errors should be exposed as ShardedReadRowsExceptionGroups""" - from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup - from google.cloud.bigtable.data.exceptions import FailedQueryShardError - - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = RuntimeError("mock error") - query_1 = ReadRowsQuery(b"test_1") - query_2 = ReadRowsQuery(b"test_2") - with pytest.raises(ShardedReadRowsExceptionGroup) as exc: - table.read_rows_sharded([query_1, query_2]) - exc_group = exc.value - assert isinstance(exc_group, ShardedReadRowsExceptionGroup) - assert len(exc.value.exceptions) == 2 - assert isinstance(exc.value.exceptions[0], FailedQueryShardError) - assert isinstance(exc.value.exceptions[0].__cause__, RuntimeError) - assert exc.value.exceptions[0].index == 0 - assert exc.value.exceptions[0].query == query_1 - assert isinstance(exc.value.exceptions[1], FailedQueryShardError) - assert isinstance(exc.value.exceptions[1].__cause__, RuntimeError) - assert exc.value.exceptions[1].index == 1 - assert exc.value.exceptions[1].query == query_2 - - def test_read_rows_sharded_concurrent(self): - """Ensure sharded requests are concurrent""" - import time - - def mock_call(*args, **kwargs): - time.sleep(0.1) - return [mock.Mock()] - - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = mock_call - queries = [ReadRowsQuery() for _ in range(10)] - start_time = time.monotonic() - result = table.read_rows_sharded(queries) - call_time = time.monotonic() - start_time - assert read_rows.call_count == 10 - assert len(result) == 10 - assert call_time < 0.2 - - def test_read_rows_sharded_batching(self): - """ - Large queries should be processed in batches to limit concurrency - operation timeout should change between batches - """ - from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT - - assert _CONCURRENCY_LIMIT == 10 - n_queries = 90 - expected_num_batches = n_queries // _CONCURRENCY_LIMIT - query_list = [ReadRowsQuery() for _ in range(n_queries)] - start_operation_timeout = 10 - start_attempt_timeout = 3 - client = self._make_client(use_emulator=True) - table = client.get_table( - "instance", - 
"table", - default_read_rows_operation_timeout=start_operation_timeout, - default_read_rows_attempt_timeout=start_attempt_timeout, - ) - - def mock_time_generator(start_op, _): - for i in range(0, 100000): - yield (start_op - i) - - with mock.patch( - "google.cloud.bigtable.data._helpers._attempt_timeout_generator" - ) as time_gen_mock: - time_gen_mock.side_effect = mock_time_generator - with mock.patch.object(table, "read_rows", mock.Mock()) as read_rows_mock: - read_rows_mock.return_value = [] - table.read_rows_sharded(query_list) - assert read_rows_mock.call_count == n_queries - kwargs = [ - read_rows_mock.call_args_list[idx][1] for idx in range(n_queries) - ] - for batch_idx in range(expected_num_batches): - batch_kwargs = kwargs[ - batch_idx - * _CONCURRENCY_LIMIT : (batch_idx + 1) - * _CONCURRENCY_LIMIT - ] - for req_kwargs in batch_kwargs: - expected_operation_timeout = start_operation_timeout - batch_idx - assert ( - req_kwargs["operation_timeout"] - == expected_operation_timeout - ) - expected_attempt_timeout = min( - start_attempt_timeout, expected_operation_timeout - ) - assert req_kwargs["attempt_timeout"] == expected_attempt_timeout - - -class TestSampleRowKeys(ABC): - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): - from google.cloud.bigtable_v2.types import SampleRowKeysResponse - - for value in sample_list: - yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) - - def test_sample_row_keys(self): - """Test that method returns the expected key samples""" - samples = [(b"test_1", 0), (b"test_2", 100), (b"test_3", 200)] - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() - ) as sample_row_keys: - sample_row_keys.return_value = self._make_gapic_stream(samples) - result = table.sample_row_keys() - assert len(result) == 3 - assert all((isinstance(r, tuple) for r in result)) - assert all((isinstance(r[0], bytes) for r in result)) - assert all((isinstance(r[1], int) for r in result)) - assert result[0] == samples[0] - assert result[1] == samples[1] - assert result[2] == samples[2] - - def test_sample_row_keys_bad_timeout(self): - """should raise error if timeout is negative""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as e: - table.sample_row_keys(operation_timeout=-1) - assert "operation_timeout must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - table.sample_row_keys(attempt_timeout=-1) - assert "attempt_timeout must be greater than 0" in str(e.value) - - def test_sample_row_keys_default_timeout(self): - """Should fallback to using table default operation_timeout""" - expected_timeout = 99 - with self._make_client() as client: - with client.get_table( - "i", - "t", - default_operation_timeout=expected_timeout, - default_attempt_timeout=expected_timeout, - ) as table: - with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() - ) as sample_row_keys: - sample_row_keys.return_value = self._make_gapic_stream([]) - result = table.sample_row_keys() - (_, kwargs) = sample_row_keys.call_args - assert abs(kwargs["timeout"] - expected_timeout) < 0.1 - assert result == [] - assert kwargs["retry"] is None - - def test_sample_row_keys_gapic_params(self): - """make sure arguments are 
propagated to gapic call as expected""" - expected_timeout = 10 - expected_profile = "test1" - instance = "instance_name" - table_id = "my_table" - with self._make_client() as client: - with client.get_table( - instance, table_id, app_profile_id=expected_profile - ) as table: - with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() - ) as sample_row_keys: - sample_row_keys.return_value = self._make_gapic_stream([]) - table.sample_row_keys(attempt_timeout=expected_timeout) - (args, kwargs) = sample_row_keys.call_args - assert len(args) == 0 - assert len(kwargs) == 5 - assert kwargs["timeout"] == expected_timeout - assert kwargs["app_profile_id"] == expected_profile - assert kwargs["table_name"] == table.table_name - assert kwargs["metadata"] is not None - assert kwargs["retry"] is None - - @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_sample_row_keys_retryable_errors(self, retryable_exception): - """retryable errors should be retried until timeout""" - from google.api_core.exceptions import DeadlineExceeded - from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() - ) as sample_row_keys: - sample_row_keys.side_effect = retryable_exception("mock") - with pytest.raises(DeadlineExceeded) as e: - table.sample_row_keys(operation_timeout=0.05) - cause = e.value.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert len(cause.exceptions) > 0 - assert isinstance(cause.exceptions[0], retryable_exception) - - @pytest.mark.parametrize( - "non_retryable_exception", - [ - core_exceptions.OutOfRange, - core_exceptions.NotFound, - core_exceptions.FailedPrecondition, - RuntimeError, - ValueError, - core_exceptions.Aborted, - ], - ) - def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): - """non-retryable errors should cause a raise""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() - ) as sample_row_keys: - sample_row_keys.side_effect = non_retryable_exception("mock") - with pytest.raises(non_retryable_exception): - table.sample_row_keys() - - -class TestMutateRow(ABC): - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - @pytest.mark.parametrize( - "mutation_arg", - [ - mutations.SetCell("family", b"qualifier", b"value"), - mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=1234567890 - ), - mutations.DeleteRangeFromColumn("family", b"qualifier"), - mutations.DeleteAllFromFamily("family"), - mutations.DeleteAllFromRow(), - [mutations.SetCell("family", b"qualifier", b"value")], - [ - mutations.DeleteRangeFromColumn("family", b"qualifier"), - mutations.DeleteAllFromRow(), - ], - ], - ) - def test_mutate_row(self, mutation_arg): - """Test mutations with no errors""" - expected_attempt_timeout = 19 - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.return_value = None - table.mutate_row( - "row_key", - mutation_arg, - attempt_timeout=expected_attempt_timeout, - ) - assert mock_gapic.call_count == 1 - 
kwargs = mock_gapic.call_args_list[0].kwargs - assert ( - kwargs["table_name"] - == "projects/project/instances/instance/tables/table" - ) - assert kwargs["row_key"] == b"row_key" - formatted_mutations = ( - [mutation._to_pb() for mutation in mutation_arg] - if isinstance(mutation_arg, list) - else [mutation_arg._to_pb()] - ) - assert kwargs["mutations"] == formatted_mutations - assert kwargs["timeout"] == expected_attempt_timeout - assert kwargs["retry"] is None - - @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_mutate_row_retryable_errors(self, retryable_exception): - from google.api_core.exceptions import DeadlineExceeded - from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(DeadlineExceeded) as e: - mutation = mutations.DeleteAllFromRow() - assert mutation.is_idempotent() is True - table.mutate_row("row_key", mutation, operation_timeout=0.01) - cause = e.value.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert isinstance(cause.exceptions[0], retryable_exception) - - @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_mutate_row_non_idempotent_retryable_errors(self, retryable_exception): - """Non-idempotent mutations should not be retried""" - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(retryable_exception): - mutation = mutations.SetCell( - "family", b"qualifier", b"value", -1 - ) - assert mutation.is_idempotent() is False - table.mutate_row("row_key", mutation, operation_timeout=0.2) - - @pytest.mark.parametrize( - "non_retryable_exception", - [ - core_exceptions.OutOfRange, - core_exceptions.NotFound, - core_exceptions.FailedPrecondition, - RuntimeError, - ValueError, - core_exceptions.Aborted, - ], - ) - def test_mutate_row_non_retryable_errors(self, non_retryable_exception): - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = non_retryable_exception("mock") - with pytest.raises(non_retryable_exception): - mutation = mutations.SetCell( - "family", - b"qualifier", - b"value", - timestamp_micros=1234567890, - ) - assert mutation.is_idempotent() is True - table.mutate_row("row_key", mutation, operation_timeout=0.2) - - @pytest.mark.parametrize("include_app_profile", [True, False]) - def test_mutate_row_metadata(self, include_app_profile): - """request should attach metadata headers""" - profile = "profile" if include_app_profile else None - with self._make_client() as client: - with client.get_table("i", "t", app_profile_id=profile) as table: - with mock.patch.object( - client._gapic_client, "mutate_row", mock.Mock() - ) as read_rows: - table.mutate_row("rk", mock.Mock()) - kwargs = read_rows.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == 
"x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata - - @pytest.mark.parametrize("mutations", [[], None]) - def test_mutate_row_no_mutations(self, mutations): - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as e: - table.mutate_row("key", mutations=mutations) - assert e.value.args[0] == "No mutations provided" - - -class TestBulkMutateRows(ABC): - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - def _mock_response(self, response_list): - from google.cloud.bigtable_v2.types import MutateRowsResponse - from google.rpc import status_pb2 - - statuses = [] - for response in response_list: - if isinstance(response, core_exceptions.GoogleAPICallError): - statuses.append( - status_pb2.Status( - message=str(response), code=response.grpc_status_code.value[0] - ) - ) - else: - statuses.append(status_pb2.Status(code=0)) - entries = [ - MutateRowsResponse.Entry(index=i, status=statuses[i]) - for i in range(len(response_list)) - ] - - def generator(): - yield MutateRowsResponse(entries=entries) - - return generator() - - @pytest.mark.parametrize( - "mutation_arg", - [ - [mutations.SetCell("family", b"qualifier", b"value")], - [ - mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=1234567890 - ) - ], - [mutations.DeleteRangeFromColumn("family", b"qualifier")], - [mutations.DeleteAllFromFamily("family")], - [mutations.DeleteAllFromRow()], - [mutations.SetCell("family", b"qualifier", b"value")], - [ - mutations.DeleteRangeFromColumn("family", b"qualifier"), - mutations.DeleteAllFromRow(), - ], - ], - ) - def test_bulk_mutate_rows(self, mutation_arg): - """Test mutations with no errors""" - expected_attempt_timeout = 19 - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.return_value = self._mock_response([None]) - bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg) - table.bulk_mutate_rows( - [bulk_mutation], attempt_timeout=expected_attempt_timeout - ) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args[1] - assert ( - kwargs["table_name"] - == "projects/project/instances/instance/tables/table" - ) - assert kwargs["entries"] == [bulk_mutation._to_pb()] - assert kwargs["timeout"] == expected_attempt_timeout - assert kwargs["retry"] is None - - def test_bulk_mutate_rows_multiple_entries(self): - """Test mutations with no errors""" - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.return_value = self._mock_response([None, None]) - mutation_list = [mutations.DeleteAllFromRow()] - entry_1 = mutations.RowMutationEntry(b"row_key_1", mutation_list) - entry_2 = mutations.RowMutationEntry(b"row_key_2", mutation_list) - table.bulk_mutate_rows([entry_1, entry_2]) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args[1] - assert ( - kwargs["table_name"] - == "projects/project/instances/instance/tables/table" - ) - assert kwargs["entries"][0] == 
entry_1._to_pb() - assert kwargs["entries"][1] == entry_2._to_pb() - - @pytest.mark.parametrize( - "exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_bulk_mutate_rows_idempotent_mutation_error_retryable(self, exception): - """Individual idempotent mutations should be retried if they fail with a retryable error""" - from google.cloud.bigtable.data.exceptions import ( - RetryExceptionGroup, - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = lambda *a, **k: self._mock_response( - [exception("mock")] - ) - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.DeleteAllFromRow() - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - table.bulk_mutate_rows([entry], operation_timeout=0.05) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert "non-idempotent" not in str(failed_exception) - assert isinstance(failed_exception, FailedMutationEntryError) - cause = failed_exception.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert isinstance(cause.exceptions[0], exception) - assert isinstance( - cause.exceptions[-1], core_exceptions.DeadlineExceeded - ) - - @pytest.mark.parametrize( - "exception", - [ - core_exceptions.OutOfRange, - core_exceptions.NotFound, - core_exceptions.FailedPrecondition, - core_exceptions.Aborted, - ], - ) - def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable(self, exception): - """Individual idempotent mutations should not be retried if they fail with a non-retryable error""" - from google.cloud.bigtable.data.exceptions import ( - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = lambda *a, **k: self._mock_response( - [exception("mock")] - ) - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.DeleteAllFromRow() - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - table.bulk_mutate_rows([entry], operation_timeout=0.05) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert "non-idempotent" not in str(failed_exception) - assert isinstance(failed_exception, FailedMutationEntryError) - cause = failed_exception.__cause__ - assert isinstance(cause, exception) - - @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_bulk_mutate_idempotent_retryable_request_errors(self, retryable_exception): - """Individual idempotent mutations should be retried if the request fails with a retryable error""" - from google.cloud.bigtable.data.exceptions import ( - RetryExceptionGroup, - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - 
"family", b"qualifier", b"value", timestamp_micros=123 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - table.bulk_mutate_rows([entry], operation_timeout=0.05) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" not in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert isinstance(cause.exceptions[0], retryable_exception) - - @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_bulk_mutate_rows_non_idempotent_retryable_errors( - self, retryable_exception - ): - """Non-Idempotent mutations should never be retried""" - from google.cloud.bigtable.data.exceptions import ( - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = lambda *a, **k: self._mock_response( - [retryable_exception("mock")] - ) - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", -1 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is False - table.bulk_mutate_rows([entry], operation_timeout=0.2) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, retryable_exception) - - @pytest.mark.parametrize( - "non_retryable_exception", - [ - core_exceptions.OutOfRange, - core_exceptions.NotFound, - core_exceptions.FailedPrecondition, - RuntimeError, - ValueError, - ], - ) - def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception): - """If the request fails with a non-retryable error, mutations should not be retried""" - from google.cloud.bigtable.data.exceptions import ( - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = non_retryable_exception("mock") - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - table.bulk_mutate_rows([entry], operation_timeout=0.2) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" not in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, non_retryable_exception) - - def test_bulk_mutate_error_index(self): - """Test partial failure, partial success. 
Errors should be associated with the correct index""" - from google.api_core.exceptions import ( - DeadlineExceeded, - ServiceUnavailable, - FailedPrecondition, - ) - from google.cloud.bigtable.data.exceptions import ( - RetryExceptionGroup, - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = [ - self._mock_response([None, ServiceUnavailable("mock"), None]), - self._mock_response([DeadlineExceeded("mock")]), - self._mock_response([FailedPrecondition("final")]), - ] - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entries = [ - mutations.RowMutationEntry( - f"row_key_{i}".encode(), [mutation] - ) - for i in range(3) - ] - assert mutation.is_idempotent() is True - table.bulk_mutate_rows(entries, operation_timeout=1000) - assert len(e.value.exceptions) == 1 - failed = e.value.exceptions[0] - assert isinstance(failed, FailedMutationEntryError) - assert failed.index == 1 - assert failed.entry == entries[1] - cause = failed.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert len(cause.exceptions) == 3 - assert isinstance(cause.exceptions[0], ServiceUnavailable) - assert isinstance(cause.exceptions[1], DeadlineExceeded) - assert isinstance(cause.exceptions[2], FailedPrecondition) - - def test_bulk_mutate_error_recovery(self): - """If an error occurs, then resolves, no exception should be raised""" - from google.api_core.exceptions import DeadlineExceeded - - with self._make_client(project="project") as client: - table = client.get_table("instance", "table") - with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: - mock_gapic.side_effect = [ - self._mock_response([DeadlineExceeded("mock")]), - self._mock_response([None]), - ] - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entries = [ - mutations.RowMutationEntry(f"row_key_{i}".encode(), [mutation]) - for i in range(3) - ] - table.bulk_mutate_rows(entries, operation_timeout=1000) - - -class TestCheckAndMutateRow(ABC): - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - @pytest.mark.parametrize("gapic_result", [True, False]) - def test_check_and_mutate(self, gapic_result): - from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - - app_profile = "app_profile_id" - with self._make_client() as client: - with client.get_table( - "instance", "table", app_profile_id=app_profile - ) as table: - with mock.patch.object( - client._gapic_client, "check_and_mutate_row" - ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=gapic_result - ) - row_key = b"row_key" - predicate = None - true_mutations = [mock.Mock()] - false_mutations = [mock.Mock(), mock.Mock()] - operation_timeout = 0.2 - found = table.check_and_mutate_row( - row_key, - predicate, - true_case_mutations=true_mutations, - false_case_mutations=false_mutations, - operation_timeout=operation_timeout, - ) - assert found == gapic_result - kwargs = mock_gapic.call_args[1] - assert kwargs["table_name"] == table.table_name - assert kwargs["row_key"] == row_key - assert kwargs["predicate_filter"] == predicate - assert kwargs["true_mutations"] == [ - m._to_pb() for m in true_mutations - ] - 
assert kwargs["false_mutations"] == [ - m._to_pb() for m in false_mutations - ] - assert kwargs["app_profile_id"] == app_profile - assert kwargs["timeout"] == operation_timeout - assert kwargs["retry"] is None - - def test_check_and_mutate_bad_timeout(self): - """Should raise error if operation_timeout < 0""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as e: - table.check_and_mutate_row( - b"row_key", - None, - true_case_mutations=[mock.Mock()], - false_case_mutations=[], - operation_timeout=-1, - ) - assert str(e.value) == "operation_timeout must be greater than 0" - - def test_check_and_mutate_single_mutations(self): - """if single mutations are passed, they should be internally wrapped in a list""" - from google.cloud.bigtable.data.mutations import SetCell - from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "check_and_mutate_row" - ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True - ) - true_mutation = SetCell("family", b"qualifier", b"value") - false_mutation = SetCell("family", b"qualifier", b"value") - table.check_and_mutate_row( - b"row_key", - None, - true_case_mutations=true_mutation, - false_case_mutations=false_mutation, - ) - kwargs = mock_gapic.call_args[1] - assert kwargs["true_mutations"] == [true_mutation._to_pb()] - assert kwargs["false_mutations"] == [false_mutation._to_pb()] - - def test_check_and_mutate_predicate_object(self): - """predicate filter should be passed to gapic request""" - from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - - mock_predicate = mock.Mock() - predicate_pb = {"predicate": "dict"} - mock_predicate._to_pb.return_value = predicate_pb - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "check_and_mutate_row" - ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True - ) - table.check_and_mutate_row( - b"row_key", mock_predicate, false_case_mutations=[mock.Mock()] - ) - kwargs = mock_gapic.call_args[1] - assert kwargs["predicate_filter"] == predicate_pb - assert mock_predicate._to_pb.call_count == 1 - assert kwargs["retry"] is None - - def test_check_and_mutate_mutations_parsing(self): - """mutations objects should be converted to protos""" - from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - from google.cloud.bigtable.data.mutations import DeleteAllFromRow - - mutations = [mock.Mock() for _ in range(5)] - for idx, mutation in enumerate(mutations): - mutation._to_pb.return_value = f"fake {idx}" - mutations.append(DeleteAllFromRow()) - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "check_and_mutate_row" - ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True - ) - table.check_and_mutate_row( - b"row_key", - None, - true_case_mutations=mutations[0:2], - false_case_mutations=mutations[2:], - ) - kwargs = mock_gapic.call_args[1] - assert kwargs["true_mutations"] == ["fake 0", "fake 1"] - assert kwargs["false_mutations"] == [ - "fake 2", - "fake 3", - "fake 4", - DeleteAllFromRow()._to_pb(), - ] - assert all( - (mutation._to_pb.call_count == 1 for mutation in 
mutations[:5]) - ) - - -class TestReadModifyWriteRow(ABC): - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - @pytest.mark.parametrize( - "call_rules,expected_rules", - [ - ( - AppendValueRule("f", "c", b"1"), - [AppendValueRule("f", "c", b"1")._to_pb()], - ), - ( - [AppendValueRule("f", "c", b"1")], - [AppendValueRule("f", "c", b"1")._to_pb()], - ), - (IncrementRule("f", "c", 1), [IncrementRule("f", "c", 1)._to_pb()]), - ( - [AppendValueRule("f", "c", b"1"), IncrementRule("f", "c", 1)], - [ - AppendValueRule("f", "c", b"1")._to_pb(), - IncrementRule("f", "c", 1)._to_pb(), - ], - ), - ], - ) - def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): - """Test that the gapic call is called with given rules""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - table.read_modify_write_row("key", call_rules) - assert mock_gapic.call_count == 1 - found_kwargs = mock_gapic.call_args_list[0][1] - assert found_kwargs["rules"] == expected_rules - assert found_kwargs["retry"] is None - - @pytest.mark.parametrize("rules", [[], None]) - def test_read_modify_write_no_rules(self, rules): - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as e: - table.read_modify_write_row("key", rules=rules) - assert e.value.args[0] == "rules must contain at least one item" - - def test_read_modify_write_call_defaults(self): - instance = "instance1" - table_id = "table1" - project = "project1" - row_key = "row_key1" - with self._make_client(project=project) as client: - with client.get_table(instance, table_id) as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - table.read_modify_write_row(row_key, mock.Mock()) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args_list[0][1] - assert ( - kwargs["table_name"] - == f"projects/{project}/instances/{instance}/tables/{table_id}" - ) - assert kwargs["app_profile_id"] is None - assert kwargs["row_key"] == row_key.encode() - assert kwargs["timeout"] > 1 - - def test_read_modify_write_call_overrides(self): - row_key = b"row_key1" - expected_timeout = 12345 - profile_id = "profile1" - with self._make_client() as client: - with client.get_table( - "instance", "table_id", app_profile_id=profile_id - ) as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - table.read_modify_write_row( - row_key, mock.Mock(), operation_timeout=expected_timeout - ) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args_list[0][1] - assert kwargs["app_profile_id"] is profile_id - assert kwargs["row_key"] == row_key - assert kwargs["timeout"] == expected_timeout - - def test_read_modify_write_string_key(self): - row_key = "string_row_key1" - with self._make_client() as client: - with client.get_table("instance", "table_id") as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - table.read_modify_write_row(row_key, mock.Mock()) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args_list[0][1] - assert kwargs["row_key"] == row_key.encode() - - def test_read_modify_write_row_building(self): - """results from gapic call should be used to construct row""" - from google.cloud.bigtable.data.row import Row - from 
google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse - from google.cloud.bigtable_v2.types import Row as RowPB - - mock_response = ReadModifyWriteRowResponse(row=RowPB()) - with self._make_client() as client: - with client.get_table("instance", "table_id") as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - with mock.patch.object(Row, "_from_pb") as constructor_mock: - mock_gapic.return_value = mock_response - table.read_modify_write_row("key", mock.Mock()) - assert constructor_mock.call_count == 1 - constructor_mock.assert_called_once_with(mock_response.row) - - -class TestReadRowsAcceptance(ABC): - @staticmethod - def _get_operation_class(): - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - - return _ReadRowsOperation - - @staticmethod - def _get_client_class(): - from google.cloud.bigtable.data._sync.client import BigtableDataClient - - return BigtableDataClient - - def parse_readrows_acceptance_tests(): - dirname = os.path.dirname(__file__) - filename = os.path.join(dirname, "../read-rows-acceptance-test.json") - with open(filename) as json_file: - test_json = TestFile.from_json(json_file.read()) - return test_json.read_rows_tests - - @staticmethod - def extract_results_from_row(row: Row): - results = [] - for family, col, cells in row.items(): - for cell in cells: - results.append( - ReadRowsTest.Result( - row_key=row.row_key, - family_name=family, - qualifier=col, - timestamp_micros=cell.timestamp_ns // 1000, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - ) - return results - - @staticmethod - def _coro_wrapper(stream): - return stream - - def _process_chunks(self, *chunks): - def _row_stream(): - yield ReadRowsResponse(chunks=chunks) - - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = None - chunker = self._get_operation_class().chunk_stream( - instance, self._coro_wrapper(_row_stream()) - ) - merger = self._get_operation_class().merge_rows(chunker) - results = [] - for row in merger: - results.append(row) - return results - - @pytest.mark.parametrize( - "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description - ) - def test_row_merger_scenario(self, test_case: ReadRowsTest): - def _scenerio_stream(): - for chunk in test_case.chunks: - yield ReadRowsResponse(chunks=[chunk]) - - try: - results = [] - instance = mock.Mock() - instance._last_yielded_row_key = None - instance._remaining_count = None - chunker = self._get_operation_class().chunk_stream( - instance, self._coro_wrapper(_scenerio_stream()) - ) - merger = self._get_operation_class().merge_rows(chunker) - for row in merger: - for cell in row: - cell_result = ReadRowsTest.Result( - row_key=cell.row_key, - family_name=cell.family, - qualifier=cell.qualifier, - timestamp_micros=cell.timestamp_micros, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - @pytest.mark.parametrize( - "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description - ) - def test_read_rows_scenario(self, test_case: ReadRowsTest): - def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): - from google.cloud.bigtable_v2 import ReadRowsResponse - - class mock_stream: - def __init__(self, chunk_list): - self.chunk_list = chunk_list - self.idx 
= -1 - - def __iter__(self): - return self - - def __next__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - chunk = self.chunk_list[self.idx] - return ReadRowsResponse(chunks=[chunk]) - raise StopIteration - - def cancel(self): - pass - - return mock_stream(chunk_list) - - with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - client = self._get_client_class()() - try: - table = client.get_table("instance", "table") - results = [] - with mock.patch.object( - table.client._gapic_client, "read_rows" - ) as read_rows: - read_rows.return_value = _make_gapic_stream(test_case.chunks) - for row in table.read_rows_stream(query={}): - for cell in row: - cell_result = ReadRowsTest.Result( - row_key=cell.row_key, - family_name=cell.family, - qualifier=cell.qualifier, - timestamp_micros=cell.timestamp_micros, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - finally: - client.close() - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - def test_out_of_order_rows(self): - def _row_stream(): - yield ReadRowsResponse(last_scanned_row_key=b"a") - - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = b"b" - chunker = self._get_operation_class().chunk_stream( - instance, self._coro_wrapper(_row_stream()) - ) - merger = self._get_operation_class().merge_rows(chunker) - with pytest.raises(InvalidChunk): - for _ in merger: - pass - - def test_bare_reset(self): - first_chunk = ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk( - row_key=b"a", family_name="f", qualifier=b"q", value=b"v" - ) - ) - with pytest.raises(InvalidChunk): - self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") - ), - ) - with pytest.raises(InvalidChunk): - self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, family_name="f") - ), - ) - with pytest.raises(InvalidChunk): - self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") - ), - ) - with pytest.raises(InvalidChunk): - self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) - ), - ) - with pytest.raises(InvalidChunk): - self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) - ), - ) - with pytest.raises(InvalidChunk): - self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, value=b"v") - ), - ) - - def test_missing_family(self): - with pytest.raises(InvalidChunk): - self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - qualifier=b"q", - timestamp_micros=1000, - value=b"v", - commit_row=True, - ) - ) - - def test_mid_cell_row_key_change(self): - with pytest.raises(InvalidChunk): - self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), - ) - - def test_mid_cell_family_change(self): - with pytest.raises(InvalidChunk): - self._process_chunks( - ReadRowsResponse.CellChunk( 
- row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - family_name="f2", value=b"v", commit_row=True - ), - ) - - def test_mid_cell_qualifier_change(self): - with pytest.raises(InvalidChunk): - self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - qualifier=b"q2", value=b"v", commit_row=True - ), - ) - - def test_mid_cell_timestamp_change(self): - with pytest.raises(InvalidChunk): - self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - timestamp_micros=2000, value=b"v", commit_row=True - ), - ) - - def test_mid_cell_labels_change(self): - with pytest.raises(InvalidChunk): - self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), - ) From c59eec2982d40cd2de435a30aca17d26e557bcc0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 16:50:39 -0700 Subject: [PATCH 093/360] got unit tests passing --- google/cloud/bigtable/data/_sync/cross_sync.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 9722cf4e8..0281e29a0 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -14,8 +14,7 @@ # from __future__ import annotations -from typing import TypeVar, Any, Awaitable, Callable, Coroutine, Sequence -from typing_extensions import TypeAlias +from typing import TypeVar, Any, Awaitable, Callable, Coroutine, Sequence, TYPE_CHECKING import asyncio import sys @@ -25,6 +24,9 @@ import threading import queue +if TYPE_CHECKING: + from typing_extensions import TypeAlias + T = TypeVar("T") From 8ef90472d6e5f23b2f7e73bba3ab6eee532f0934 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 17:17:56 -0700 Subject: [PATCH 094/360] updated sync files --- .../cloud/bigtable/data/_sync/_mutate_rows.py | 29 +- .../cloud/bigtable/data/_sync/_read_rows.py | 49 ++- google/cloud/bigtable/data/_sync/client.py | 324 ++++++++---------- .../bigtable/data/_sync/mutations_batcher.py | 147 ++++---- 4 files changed, 283 insertions(+), 266 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index c368db229..37849eca4 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -45,6 +45,14 @@ class _MutateRowsOperation: Errors are exposed as a MutationsExceptionGroup, which contains a list of exceptions organized by the related failed mutation entries. + + Args: + gapic_client: the client to use for the mutate_rows call + table: the table associated with the request + mutation_entries: a list of RowMutationEntry objects to send to the server + operation_timeout: the timeout to use for the entire operation, in seconds. + attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds. + If not specified, the request will run until operation_timeout is reached. 
""" def __init__( @@ -56,15 +64,6 @@ def __init__( attempt_timeout: float | None, retryable_exceptions: Sequence[type[Exception]] = (), ): - """ - Args: - - gapic_client: the client to use for the mutate_rows call - - table: the table associated with the request - - mutation_entries: a list of RowMutationEntry objects to send to the server - - operation_timeout: the timeout to use for the entire operation, in seconds. - - attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds. - If not specified, the request will run until operation_timeout is reached. - """ total_mutations = sum((len(entry.mutations) for entry in mutation_entries)) if total_mutations > _MUTATE_ROWS_REQUEST_MUTATION_LIMIT: raise ValueError( @@ -101,7 +100,7 @@ def start(self): Start the operation, and run until completion Raises: - - MutationsExceptionGroup: if any mutations failed + MutationsExceptionGroup: if any mutations failed """ try: self._operation() @@ -130,9 +129,9 @@ def _run_attempt(self): Run a single attempt of the mutate_rows rpc. Raises: - - _MutateRowsIncomplete: if there are failed mutations eligible for - retry after the attempt is complete - - GoogleAPICallError: if the gapic rpc fails + _MutateRowsIncomplete: if there are failed mutations eligible for + retry after the attempt is complete + GoogleAPICallError: if the gapic rpc fails """ request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] active_request_indices = { @@ -174,8 +173,8 @@ def _handle_entry_error(self, idx: int, exc: Exception): retryable. Args: - - idx: the index of the mutation that failed - - exc: the exception to add to the list + idx: the index of the mutation that failed + exc: the exception to add to the list """ entry = self.mutations[idx].entry self.errors.setdefault(idx, []).append(exc) diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index 6225414ea..d43782a8f 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -47,6 +47,13 @@ class _ReadRowsOperation: ReadRowsOperation(request, client) handles row merging logic end-to-end, including performing retries on stream errors. + + Args: + query: The query to execute + table: The table to send the request to + operation_timeout: The total time to allow for the operation, in seconds + attempt_timeout: The time to allow for each individual attempt, in seconds + retryable_exceptions: A list of exceptions that should trigger a retry """ __slots__ = ( @@ -87,7 +94,12 @@ def __init__( self._remaining_count: int | None = self.request.rows_limit or None def start_operation(self) -> Iterable[Row]: - """Start the read_rows operation, retrying on retryable errors.""" + """ + Start the read_rows operation, retrying on retryable errors. + + Yields: + Row: The next row in the stream + """ return CrossSync._Sync_Impl.retry_target_stream( self._read_rows_attempt, self._predicate, @@ -102,6 +114,9 @@ def _read_rows_attempt(self) -> Iterable[Row]: This function is intended to be wrapped by retry logic, which will call this function until it succeeds or a non-retryable error is raised. 
+ + Yields: + Row: The next row in the stream """ if self._last_yielded_row_key is not None: try: @@ -127,7 +142,14 @@ def _read_rows_attempt(self) -> Iterable[Row]: def chunk_stream( self, stream: Iterable[ReadRowsResponsePB] ) -> Iterable[ReadRowsResponsePB.CellChunk]: - """process chunks out of raw read_rows stream""" + """ + process chunks out of raw read_rows stream + + Args: + stream: the raw read_rows stream from the gapic client + Yields: + ReadRowsResponsePB.CellChunk: the next chunk in the stream + """ for resp in stream: resp = resp._pb if resp.last_scanned_row_key: @@ -160,8 +182,17 @@ def chunk_stream( current_key = None @staticmethod - def merge_rows(chunks: Iterable[ReadRowsResponsePB.CellChunk] | None): - """Merge chunks into rows""" + def merge_rows( + chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None + ) -> AsyncGenerator[Row, None]: + """ + Merge chunks into rows + + Args: + chunks: the chunk stream to merge + Yields: + Row: the next row in the stream + """ if chunks is None: return it = chunks.__iter__() @@ -257,12 +288,14 @@ def _revise_request_rowset(row_set: RowSetPB, last_seen_row_key: bytes) -> RowSe Revise the rows in the request to avoid ones we've already processed. Args: - - row_set: the row set from the request - - last_seen_row_key: the last row key encountered + row_set: the row set from the request + last_seen_row_key: the last row key encountered + Returns: + RowSetPB: the new rowset after adusting for the last seen key Raises: - - _RowSetComplete: if there are no rows left to process after the revision + _RowSetComplete: if there are no rows left to process after the revision """ - if row_set is None or (not row_set.row_ranges and row_set.row_keys is not None): + if row_set is None or (not row_set.row_ranges and (not row_set.row_keys)): last_seen = last_seen_row_key return RowSetPB(row_ranges=[RowRangePB(start_key_open=last_seen)]) adjusted_keys: list[bytes] = [ diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 975b379a2..741ea603d 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -91,9 +91,6 @@ def __init__( Client should be created within an async context (running event loop) - Warning: BigtableDataClientAsync is currently in preview, and is not - yet recommended for production use. - Args: project: the project which the client acts on behalf of. If not passed, falls back to the default inferred @@ -105,12 +102,12 @@ def __init__( client. If not passed (and if no ``_http`` object is passed), falls back to the default inferred from the environment. - client_options (Optional[Union[dict, google.api_core.client_options.ClientOptions]]): + client_options: Client options used to set user options on the client. API Endpoint should be set through client_options. 
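A minimal construction example for the generated sync client, matching the parameters documented above; the project id and endpoint value are placeholders:

    from google.api_core.client_options import ClientOptions
    from google.cloud.bigtable.data._sync.client import BigtableDataClient

    client = BigtableDataClient(
        project="my-project",
        client_options=ClientOptions(api_endpoint="bigtable.googleapis.com"),
    )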
Raises: - - RuntimeError if called outside of an async context (no running event loop) - - ValueError if pool_size is less than 1 + RuntimeError: if called outside of an async context (no running event loop) + ValueError: if pool_size is less than 1 """ transport_str = f"bt-{self._client_version()}-{pool_size}" transport = PooledBigtableGrpcTransport.with_fixed_size(pool_size) @@ -183,8 +180,9 @@ def _client_version() -> str: def _start_background_channel_refresh(self) -> None: """ Starts a background task to ping and warm each channel in the pool + Raises: - - RuntimeError if not called in an asyncio event loop + RuntimeError: if not called in an asyncio event loop """ if ( not self._channel_refresh_tasks @@ -225,10 +223,10 @@ def _ping_and_warm_instances( Pings each Bigtable instance registered in `_active_instances` on the client Args: - - channel: grpc channel to warm - - instance_key: if provided, only warm the instance associated with the key + channel: grpc channel to warm + instance_key: if provided, only warm the instance associated with the key Returns: - - sequence of results or exceptions from the ping requests + list[BaseException | None]: sequence of results or exceptions from the ping requests """ instance_list = ( [instance_key] if instance_key is not None else self._active_instances @@ -316,10 +314,10 @@ def _register_instance(self, instance_id: str, owner: Table) -> None: Channels will not be refreshed unless at least one instance is registered Args: - - instance_id: id of the instance to register. - - owner: table that owns the instance. Owners will be tracked in - _instance_owners, and instances will only be unregistered when all - owners call _remove_instance_registration + instance_id: id of the instance to register. + owner: table that owns the instance. Owners will be tracked in + _instance_owners, and instances will only be unregistered when all + owners call _remove_instance_registration """ instance_name = self._gapic_client.instance_path(self.project, instance_id) instance_key = _helpers._WarmedInstanceKey( @@ -342,12 +340,12 @@ def _remove_instance_registration(self, instance_id: str, owner: Table) -> bool: If instance_id is not registered, or is still in use by other tables, returns False Args: - - instance_id: id of the instance to remove - - owner: table that owns the instance. Owners will be tracked in + instance_id: id of the instance to remove + owner: table that owns the instance. Owners will be tracked in _instance_owners, and instances will only be unregistered when all owners call _remove_instance_registration Returns: - - True if instance was removed + bool: True if instance was removed, else False """ instance_name = self._gapic_client.instance_path(self.project, instance_id) instance_key = _helpers._WarmedInstanceKey( @@ -396,6 +394,10 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> Table: default_retryable_errors: a list of errors that will be retried if encountered during all other operations. Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + Returns: + TableAsync: a table instance for making data API requests + Raises: + RuntimeError: if called outside of an async context (no running event loop) """ return Table(self, instance_id, table_id, *args, **kwargs) @@ -478,7 +480,7 @@ def __init__( encountered during all other operations. 
Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) Raises: - - RuntimeError if called outside of an async context (no running event loop) + RuntimeError: if called outside of an async context (no running event loop) """ _helpers._validate_timeouts( default_operation_timeout, default_attempt_timeout, allow_none=True @@ -546,28 +548,25 @@ def read_rows_stream( Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. - Warning: BigtableDataClient is currently in preview, and is not - yet recommended for production use. - Args: - - query: contains details about which rows to return - - operation_timeout: the time budget for the entire operation, in seconds. + query: contains details about which rows to return + operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget. Defaults to the Table's default_read_rows_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. + attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_read_rows_attempt_timeout. If None, defaults to operation_timeout. - - retryable_errors: a list of errors that will be retried if encountered. + retryable_errors: a list of errors that will be retried if encountered. Defaults to the Table's default_read_rows_retryable_errors Returns: - - an asynchronous iterator that yields rows returned by the query + AsyncIterable[Row]: an asynchronous iterator that yields rows returned by the query Raises: - - DeadlineExceeded: raised after operation timeout + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error """ operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self @@ -599,30 +598,27 @@ def read_rows( Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. - Warning: BigtableDataClient is currently in preview, and is not - yet recommended for production use. - Args: - - query: contains details about which rows to return - - operation_timeout: the time budget for the entire operation, in seconds. + query: contains details about which rows to return + operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget. Defaults to the Table's default_read_rows_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. + attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_read_rows_attempt_timeout. If None, defaults to operation_timeout. If None, defaults to the Table's default_read_rows_attempt_timeout, or the operation_timeout if that is also None. - - retryable_errors: a list of errors that will be retried if encountered. 
+ retryable_errors: a list of errors that will be retried if encountered. Defaults to the Table's default_read_rows_retryable_errors. Returns: - - a list of Rows returned by the query + list[Row]: a list of Rows returned by the query Raises: - - DeadlineExceeded: raised after operation timeout + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error """ row_generator = self.read_rows_stream( query, @@ -648,28 +644,25 @@ def read_row( Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. - Warning: BigtableDataClient is currently in preview, and is not - yet recommended for production use. - Args: - - query: contains details about which rows to return - - operation_timeout: the time budget for the entire operation, in seconds. + query: contains details about which rows to return + operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget. Defaults to the Table's default_read_rows_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. + attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_read_rows_attempt_timeout. If None, defaults to operation_timeout. - - retryable_errors: a list of errors that will be retried if encountered. + retryable_errors: a list of errors that will be retried if encountered. Defaults to the Table's default_read_rows_retryable_errors. Returns: - - a Row object if the row exists, otherwise None + Row | None: a Row object if the row exists, otherwise None Raises: - - DeadlineExceeded: raised after operation timeout + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error """ if row_key is None: raise ValueError("row_key must be string or bytes") @@ -697,74 +690,69 @@ def read_rows_sharded( Runs a sharded query in parallel, then return the results in a single list. Results will be returned in the order of the input queries. - This function is intended to be run on the results on a query.shard() call: + This function is intended to be run on the results on a query.shard() call. + For example:: - ``` - table_shard_keys = await table.sample_row_keys() - query = ReadRowsQuery(...) - shard_queries = query.shard(table_shard_keys) - results = await table.read_rows_sharded(shard_queries) - ``` - - Warning: BigtableDataClient is currently in preview, and is not - yet recommended for production use. + table_shard_keys = await table.sample_row_keys() + query = ReadRowsQuery(...) 
+ shard_queries = query.shard(table_shard_keys) + results = await table.read_rows_sharded(shard_queries) Args: - - sharded_query: a sharded query to execute - - operation_timeout: the time budget for the entire operation, in seconds. + sharded_query: a sharded query to execute + operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget. Defaults to the Table's default_read_rows_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. + attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_read_rows_attempt_timeout. If None, defaults to operation_timeout. - - retryable_errors: a list of errors that will be retried if encountered. + retryable_errors: a list of errors that will be retried if encountered. Defaults to the Table's default_read_rows_retryable_errors. + Returns: + list[Row]: a list of Rows returned by the query Raises: - - ShardedReadRowsExceptionGroup: if any of the queries failed - - ValueError: if the query_list is empty + ShardedReadRowsExceptionGroup: if any of the queries failed + ValueError: if the query_list is empty """ if not sharded_query: raise ValueError("empty sharded_query") operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) - timeout_generator = _helpers._attempt_timeout_generator( + rpc_timeout_generator = _helpers._attempt_timeout_generator( operation_timeout, operation_timeout ) - batched_queries = [ - sharded_query[i : i + _helpers._CONCURRENCY_LIMIT] - for i in range(0, len(sharded_query), _helpers._CONCURRENCY_LIMIT) - ] - results_list = [] - error_dict = {} - shard_idx = 0 - for batch in batched_queries: - batch_operation_timeout = next(timeout_generator) - batch_partial_list = [ - partial( - self.read_rows, - query=query, - operation_timeout=batch_operation_timeout, - attempt_timeout=min(attempt_timeout, batch_operation_timeout), + concurrency_sem = asyncio.Semaphore(_helpers._CONCURRENCY_LIMIT) + + def read_rows_with_semaphore(query): + with concurrency_sem: + shard_timeout = next(rpc_timeout_generator) + if shard_timeout <= 0: + raise DeadlineExceeded( + "Operation timeout exceeded before starting query" + ) + return self.read_rows( + query, + operation_timeout=shard_timeout, + attempt_timeout=min(attempt_timeout, shard_timeout), retryable_errors=retryable_errors, ) - for query in batch - ] - batch_result = CrossSync._Sync_Impl.gather_partials( - batch_partial_list, - return_exceptions=True, - sync_executor=self.client._executor, - ) - for result in batch_result: - if isinstance(result, Exception): - error_dict[shard_idx] = result - elif isinstance(result, BaseException): - raise result - else: - results_list.extend(result) - shard_idx += 1 + + routine_list = [read_rows_with_semaphore(query) for query in sharded_query] + batch_result = asyncio.gather(*routine_list, return_exceptions=True) + error_dict = {} + shard_idx = 0 + results_list = [] + for result in batch_result: + if isinstance(result, Exception): + error_dict[shard_idx] = result + elif isinstance(result, BaseException): + raise result + else: + results_list.extend(result) + shard_idx += 1 if error_dict: raise ShardedReadRowsExceptionGroup( [ @@ -789,28 +777,25 @@ def row_exists( Return a boolean indicating whether the specified row exists 
in the table. uses the filters: chain(limit cells per row = 1, strip value) - Warning: BigtableDataClient is currently in preview, and is not - yet recommended for production use. - Args: - - row_key: the key of the row to check - - operation_timeout: the time budget for the entire operation, in seconds. + row_key: the key of the row to check + operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget. Defaults to the Table's default_read_rows_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. + attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_read_rows_attempt_timeout. If None, defaults to operation_timeout. - - retryable_errors: a list of errors that will be retried if encountered. + retryable_errors: a list of errors that will be retried if encountered. Defaults to the Table's default_read_rows_retryable_errors. Returns: - - a bool indicating whether the row exists + bool: a bool indicating whether the row exists Raises: - - DeadlineExceeded: raised after operation timeout + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error """ if row_key is None: raise ValueError("row_key must be string or bytes") @@ -843,29 +828,26 @@ def sample_row_keys( requests will call sample_row_keys internally for this purpose when sharding is enabled RowKeySamples is simply a type alias for list[tuple[bytes, int]]; a list of - row_keys, along with offset positions in the table - - Warning: BigtableDataClient is currently in preview, and is not - yet recommended for production use. + row_keys, along with offset positions in the table Args: - - operation_timeout: the time budget for the entire operation, in seconds. + operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget.i Defaults to the Table's default_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. + attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_attempt_timeout. If None, defaults to operation_timeout. - - retryable_errors: a list of errors that will be retried if encountered. + retryable_errors: a list of errors that will be retried if encountered. Defaults to the Table's default_retryable_errors. 
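Taken together with read_rows_sharded above, these methods support a simple sharded-scan workflow; a sketch using only calls documented in this diff (instance, table, and row key values are placeholders, and the no-argument ReadRowsQuery() is assumed to mean a full-table scan):

    from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery

    with client.get_table("my-instance", "my-table") as table:
        if table.row_exists(b"user#123"):
            shard_keys = table.sample_row_keys()         # RowKeySamples: list[tuple[bytes, int]]
            queries = ReadRowsQuery().shard(shard_keys)  # one sub-query per contiguous section
            rows = table.read_rows_sharded(queries)      # results in input-query order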
Returns: - - a set of RowKeySamples the delimit contiguous sections of the table + RowKeySamples: a set of RowKeySamples the delimit contiguous sections of the table Raises: - - DeadlineExceeded: raised after operation timeout + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error """ operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self @@ -915,26 +897,23 @@ def mutations_batcher( Can be used to iteratively add mutations that are flushed as a group, to avoid excess network calls - Warning: BigtableDataClient is currently in preview, and is not - yet recommended for production use. - Args: - - flush_interval: Automatically flush every flush_interval seconds. If None, + flush_interval: Automatically flush every flush_interval seconds. If None, a table default will be used - - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count + flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count mutations are added across all entries. If None, this limit is ignored. - - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. - - flow_control_max_mutation_count: Maximum number of inflight mutations. - - flow_control_max_bytes: Maximum number of inflight bytes. - - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. + flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. + flow_control_max_mutation_count: Maximum number of inflight mutations. + flow_control_max_bytes: Maximum number of inflight bytes. + batch_operation_timeout: timeout for each mutate_rows operation, in seconds. Defaults to the Table's default_mutate_rows_operation_timeout - - batch_attempt_timeout: timeout for each individual request, in seconds. + batch_attempt_timeout: timeout for each individual request, in seconds. Defaults to the Table's default_mutate_rows_attempt_timeout. If None, defaults to batch_operation_timeout. - - batch_retryable_errors: a list of errors that will be retried if encountered. + batch_retryable_errors: a list of errors that will be retried if encountered. Defaults to the Table's default_mutate_rows_retryable_errors. Returns: - - a MutationsBatcher context manager that can batch requests + MutationsBatcherAsync: a MutationsBatcher context manager that can batch requests """ return MutationsBatcher( self, @@ -967,30 +946,27 @@ def mutate_row( Idempotent operations (i.e, all mutations have an explicit timestamp) will be retried on server failure. Non-idempotent operations will not. - Warning: BigtableDataClient is currently in preview, and is not - yet recommended for production use. - Args: - - row_key: the row to apply mutations to - - mutations: the set of mutations to apply to the row - - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_attempt_timeout. 
- If None, defaults to operation_timeout. - - retryable_errors: a list of errors that will be retried if encountered. - Only idempotent mutations will be retried. Defaults to the Table's - default_retryable_errors. + row_key: the row to apply mutations to + mutations: the set of mutations to apply to the row + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_operation_timeout + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_attempt_timeout. + If None, defaults to operation_timeout. + retryable_errors: a list of errors that will be retried if encountered. + Only idempotent mutations will be retried. Defaults to the Table's + default_retryable_errors. Raises: - - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing all - GoogleAPIError exceptions from any retries that failed - - GoogleAPIError: raised on non-idempotent operations that cannot be - safely retried. - - ValueError if invalid arguments are provided + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing all + GoogleAPIError exceptions from any retries that failed + google.api_core.exceptions.GoogleAPIError: raised on non-idempotent operations that cannot be + safely retried. + ValueError: if invalid arguments are provided """ operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self @@ -1044,27 +1020,24 @@ def bulk_mutate_rows( will be retried on failure. Non-idempotent will not, and will reported in a raised exception group - Warning: BigtableDataClient is currently in preview, and is not - yet recommended for production use. - Args: - - mutation_entries: the batches of mutations to apply + mutation_entries: the batches of mutations to apply Each entry will be applied atomically, but entries will be applied in arbitrary order - - operation_timeout: the time budget for the entire operation, in seconds. + operation_timeout: the time budget for the entire operation, in seconds. Failed requests will be retried within the budget. Defaults to the Table's default_mutate_rows_operation_timeout - - attempt_timeout: the time budget for an individual network request, in seconds. + attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_mutate_rows_attempt_timeout. If None, defaults to operation_timeout. - - retryable_errors: a list of errors that will be retried if encountered. + retryable_errors: a list of errors that will be retried if encountered. 
Defaults to the Table's default_mutate_rows_retryable_errors Raises: - - MutationsExceptionGroup if one or more mutations fails + MutationsExceptionGroup: if one or more mutations fails Contains details about any failed entries in .exceptions - - ValueError if invalid arguments are provided + ValueError: if invalid arguments are provided """ operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self @@ -1094,35 +1067,32 @@ def check_and_mutate_row( Non-idempotent operation: will not be retried - Warning: BigtableDataClient is currently in preview, and is not - yet recommended for production use. - Args: - - row_key: the key of the row to mutate - - predicate: the filter to be applied to the contents of the specified row. + row_key: the key of the row to mutate + predicate: the filter to be applied to the contents of the specified row. Depending on whether or not any results are yielded, either true_case_mutations or false_case_mutations will be executed. If None, checks that the row contains any values at all. - - true_case_mutations: + true_case_mutations: Changes to be atomically applied to the specified row if predicate yields at least one cell when applied to row_key. Entries are applied in order, meaning that earlier mutations can be masked by later ones. Must contain at least one entry if false_case_mutations is empty, and at most 100000. - - false_case_mutations: + false_case_mutations: Changes to be atomically applied to the specified row if predicate_filter does not yield any cells when applied to row_key. Entries are applied in order, meaning that earlier mutations can be masked by later ones. Must contain at least one entry if - `true_case_mutations is empty, and at most 100000. - - operation_timeout: the time budget for the entire operation, in seconds. + `true_case_mutations` is empty, and at most 100000. + operation_timeout: the time budget for the entire operation, in seconds. Failed requests will not be retried. Defaults to the Table's default_operation_timeout Returns: - - bool indicating whether the predicate was true or false + bool indicating whether the predicate was true or false Raises: - - GoogleAPIError exceptions from grpc call + google.api_core.exceptions.GoogleAPIError: exceptions from grpc call """ operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) if true_case_mutations is not None and ( @@ -1165,23 +1135,19 @@ def read_modify_write_row( Non-idempotent operation: will not be retried - Warning: BigtableDataClient is currently in preview, and is not - yet recommended for production use. - Args: - - row_key: the key of the row to apply read/modify/write rules to - - rules: A rule or set of rules to apply to the row. + row_key: the key of the row to apply read/modify/write rules to + rules: A rule or set of rules to apply to the row. Rules are applied in order, meaning that earlier rules will affect the results of later ones. - - operation_timeout: the time budget for the entire operation, in seconds. + operation_timeout: the time budget for the entire operation, in seconds. Failed requests will not be retried. Defaults to the Table's default_operation_timeout. 
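A short usage sketch for the two non-retried row operations documented above; the rule and mutation constructor signatures are assumptions based on the data-API helper modules, and the family/qualifier values are placeholders:

    from google.cloud.bigtable.data.mutations import SetCell
    from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule

    # atomic increment; rules are applied in order and the modified cells come back
    row = table.read_modify_write_row(
        b"counter#1",
        IncrementRule("stats", b"views", 1),
    )

    # predicate=None checks only that the row contains any value at all
    predicate_matched = table.check_and_mutate_row(
        b"counter#1",
        predicate=None,
        true_case_mutations=[SetCell("stats", b"seen", b"1")],
    )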
Returns: - - Row: containing cell data that was modified as part of the - operation + Row: a Row containing cell data that was modified as part of the operation Raises: - - GoogleAPIError exceptions from grpc call - - ValueError if invalid arguments are provided + google.api_core.exceptions.GoogleAPIError: exceptions from grpc call + ValueError: if invalid arguments are provided """ operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) if operation_timeout <= 0: diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 8c1dc4904..8a38af39c 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -43,15 +43,28 @@ class MutationsBatcher: Runs mutate_row, mutate_rows, and check_and_mutate_row internally, combining to use as few network requests as required - Flushes: - - every flush_interval seconds - - after queue reaches flush_count in quantity - - after queue reaches flush_size_bytes in storage size - - when batcher is closed or destroyed - - async with table.mutations_batcher() as batcher: - for i in range(10): - batcher.add(row, mut) + Will automatically flush the batcher: + - every flush_interval seconds + - after queue size reaches flush_limit_mutation_count + - after queue reaches flush_limit_bytes + - when batcher is closed or destroyed + + Args: + table: Table to preform rpc calls + flush_interval: Automatically flush every flush_interval seconds. + If None, no time-based flushing is performed. + flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count + mutations are added across all entries. If None, this limit is ignored. + flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. + flow_control_max_mutation_count: Maximum number of inflight mutations. + flow_control_max_bytes: Maximum number of inflight bytes. + batch_operation_timeout: timeout for each mutate_rows operation, in seconds. + If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_operation_timeout. + batch_attempt_timeout: timeout for each individual request, in seconds. + If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to batch_operation_timeout. + batch_retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_mutate_rows_retryable_errors. """ def __init__( @@ -68,24 +81,6 @@ def __init__( batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, ): - """ - Args: - - table: Table to preform rpc calls - - flush_interval: Automatically flush every flush_interval seconds. - If None, no time-based flushing is performed. - - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count - mutations are added across all entries. If None, this limit is ignored. - - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. - - flow_control_max_mutation_count: Maximum number of inflight mutations. - - flow_control_max_bytes: Maximum number of inflight bytes. - - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. - If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_operation_timeout. - - batch_attempt_timeout: timeout for each individual request, in seconds. - If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_attempt_timeout. - If None, defaults to batch_operation_timeout. 
- - batch_retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_mutate_rows_retryable_errors. - """ self._operation_timeout, self._attempt_timeout = _get_timeouts( batch_operation_timeout, batch_attempt_timeout, table ) @@ -125,13 +120,22 @@ def __init__( def _timer_routine(self, interval: float | None) -> None: """ - Triggers new flush tasks every `interval` seconds - Ends when the batcher is closed + Set up a background task to flush the batcher every interval seconds + + If interval is None, an empty future is returned + + Args: + flush_interval: Automatically flush every flush_interval seconds. + If None, no time-based flushing is performed. + Returns: + asyncio.Future[None]: future representing the background task """ if not interval or interval <= 0: return None while not self._closed.is_set(): - CrossSync._Sync_Impl.event_wait(self._closed, timeout=interval) + CrossSync._Sync_Impl.event_wait( + self._closed, timeout=interval, async_break_early=False + ) if not self._closed.is_set() and self._staged_entries: self._schedule_flush() @@ -139,13 +143,11 @@ def append(self, mutation_entry: RowMutationEntry): """ Add a new set of mutations to the internal queue - TODO: return a future to track completion of this entry - Args: - - mutation_entry: new entry to add to flush queue + mutation_entry: new entry to add to flush queue Raises: - - RuntimeError if batcher is closed - - ValueError if an invalid mutation type is added + RuntimeError: if batcher is closed + ValueError: if an invalid mutation type is added """ if self._closed.is_set(): raise RuntimeError("Cannot append to closed MutationsBatcher") @@ -163,8 +165,14 @@ def append(self, mutation_entry: RowMutationEntry): self._schedule_flush() CrossSync._Sync_Impl.yield_to_event_loop() - def _schedule_flush(self) -> CrossSync._Sync_Impl.Future[None] | None: - """Update the flush task to include the latest staged entries""" + def _schedule_flush(self) -> asyncio.Future[None] | None: + """ + Update the flush task to include the latest staged entries + + Returns: + asyncio.Future[None] | None: + future representing the background task, if started + """ if self._staged_entries: entries, self._staged_entries = (self._staged_entries, []) self._staged_count, self._staged_bytes = (0, 0) @@ -182,7 +190,7 @@ def _flush_internal(self, new_entries: list[RowMutationEntry]): Flushes a set of mutations to the server, and updates internal state Args: - - new_entries: list of RowMutationEntry objects to flush + new_entries list of RowMutationEntry objects to flush """ in_process_requests: list[ CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] @@ -203,12 +211,13 @@ def _execute_mutate_rows( Helper to execute mutation operation on a batch Args: - - batch: list of RowMutationEntry objects to send to server - - timeout: timeout in seconds. Used as operation_timeout and attempt_timeout. - If not given, will use table defaults + batch: list of RowMutationEntry objects to send to server + timeout: timeout in seconds. Used as operation_timeout and attempt_timeout. + If not given, will use table defaults Returns: - - list of FailedMutationEntryError objects for mutations that failed. - FailedMutationEntryError objects will not contain index information + list[FailedMutationEntryError]: + list of FailedMutationEntryError objects for mutations that failed. 
+ FailedMutationEntryError objects will not contain index information """ try: operation = _MutateRowsOperation( @@ -233,6 +242,9 @@ def _add_exceptions(self, excs: list[Exception]): Add new list of exceptions to internal store. To avoid unbounded memory, the batcher will store the first and last _exception_list_limit exceptions, and discard any in between. + + Args: + excs: list of exceptions to add to the internal store """ self._exceptions_since_last_raise += len(excs) if excs and len(self._oldest_exceptions) < self._exception_list_limit: @@ -247,7 +259,7 @@ def _raise_exceptions(self): Raise any unreported exceptions from background flush operations Raises: - - MutationsExceptionGroup with all unreported exceptions + MutationsExceptionGroup: exception group with all unreported exceptions """ if self._oldest_exceptions or self._newest_exceptions: oldest, self._oldest_exceptions = (self._oldest_exceptions, []) @@ -269,11 +281,15 @@ def _raise_exceptions(self): ) def __enter__(self): - """For context manager API""" + """Allow use of context manager API""" return self def __exit__(self, exc_type, exc, tb): - """For context manager API""" + """ + Allow use of context manager API. + + Flushes the batcher and cleans up resources. + """ self.close() @property @@ -313,12 +329,13 @@ def _wait_for_batch_results( waits for them to complete, and returns a list of errors encountered. Args: - - *tasks: futures representing _execute_mutate_rows or _flush_internal tasks + *tasks: futures representing _execute_mutate_rows or _flush_internal tasks Returns: - - list of Exceptions encountered by any of the tasks. Errors are expected - to be FailedMutationEntryError, representing a failed mutation operation. - If a task fails with a different exception, it will be included in the - output list. Successful tasks will not be represented in the output list. + list[Exception]: + list of Exceptions encountered by any of the tasks. Errors are expected + to be FailedMutationEntryError, representing a failed mutation operation. + If a task fails with a different exception, it will be included in the + output list. Successful tasks will not be represented in the output list. """ if not tasks: return [] @@ -347,15 +364,16 @@ class _FlowControl: Flow limits are not hard limits. If a single mutation exceeds the configured limits, it will be allowed as a single batch when the capacity is available. + + Args: + max_mutation_count: maximum number of mutations to send in a single rpc. + This corresponds to individual mutations in a single RowMutationEntry. + max_mutation_bytes: maximum number of bytes to send in a single rpc. + Raises: + ValueError: if max_mutation_count or max_mutation_bytes is less than 0 """ def __init__(self, max_mutation_count: int, max_mutation_bytes: int): - """ - Args: - - max_mutation_count: maximum number of mutations to send in a single rpc. - This corresponds to individual mutations in a single RowMutationEntry. - - max_mutation_bytes: maximum number of bytes to send in a single rpc. - """ self._max_mutation_count = max_mutation_count self._max_mutation_bytes = max_mutation_bytes if self._max_mutation_count < 1: @@ -375,10 +393,10 @@ def _has_capacity(self, additional_count: int, additional_size: int) -> bool: previous batches have completed. 
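The "not hard limits" behaviour above comes down to a single predicate; restated as a standalone function for illustration (the max() trick mirrors the _has_capacity hunk that follows, while the comparison against in-flight totals is inferred from the documented behaviour):

    def has_capacity(in_flight_count, in_flight_size,
                     additional_count, additional_size,
                     max_count, max_size):
        # an oversized entry is still allowed through on its own, so the effective
        # limit is never smaller than the entry currently being considered
        acceptable_size = max(max_size, additional_size)
        acceptable_count = max(max_count, additional_count)
        return (in_flight_size + additional_size <= acceptable_size
                and in_flight_count + additional_count <= acceptable_count)

    # e.g. with a 100-byte limit, a 150-byte entry is admitted only when nothing
    # else is in flight:
    assert has_capacity(0, 0, 1, 150, 10, 100) is True
    assert has_capacity(1, 20, 1, 150, 10, 100) is False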
Args: - - additional_count: number of mutations in the pending entry - - additional_size: size of the pending entry + additional_count: number of mutations in the pending entry + additional_size: size of the pending entry Returns: - - True if there is capacity to send the pending entry, False otherwise + bool: True if there is capacity to send the pending entry, False otherwise """ acceptable_size = max(self._max_mutation_bytes, additional_size) acceptable_count = max(self._max_mutation_count, additional_count) @@ -395,7 +413,7 @@ def remove_from_flow( operation is complete. Args: - - mutations: mutation or list of mutations to remove from flow control + mutations: mutation or list of mutations to remove from flow control """ if not isinstance(mutations, list): mutations = [mutations] @@ -414,10 +432,11 @@ def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): will block until there is capacity available. Args: - - mutations: list mutations to break up into batches + mutations: list mutations to break up into batches Yields: - - list of mutations that have reserved space in the flow control. - Each batch contains at least one mutation. + list[RowMutationEntry]: + list of mutations that have reserved space in the flow control. + Each batch contains at least one mutation. """ if not isinstance(mutations, list): mutations = [mutations] From 0705ee9c3d4852e193e52df58453e47dc4fa37bf Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 17:26:41 -0700 Subject: [PATCH 095/360] got sharding working with cross_sync --- google/cloud/bigtable/data/_async/client.py | 10 +++++++--- google/cloud/bigtable/data/_sync/cross_sync.py | 2 ++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index ef7633d17..39db62cc6 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -821,7 +821,7 @@ async def read_rows_sharded( ) # limit the number of concurrent requests using a semaphore - concurrency_sem = asyncio.Semaphore(_helpers._CONCURRENCY_LIMIT) + concurrency_sem = CrossSync.Semaphore(_helpers._CONCURRENCY_LIMIT) async def read_rows_with_semaphore(query): async with concurrency_sem: @@ -838,8 +838,12 @@ async def read_rows_with_semaphore(query): retryable_errors=retryable_errors, ) - routine_list = [read_rows_with_semaphore(query) for query in sharded_query] - batch_result = await asyncio.gather(*routine_list, return_exceptions=True) + routine_list = [partial(read_rows_with_semaphore, query) for query in sharded_query] + batch_result = await CrossSync.gather_partials( + routine_list, + return_exceptions=True, + sync_executor=self.client._executor + ) # collect results and errors error_dict = {} diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 0281e29a0..333730114 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -42,6 +42,7 @@ class CrossSync: Future: TypeAlias = asyncio.Future Task: TypeAlias = asyncio.Task Event: TypeAlias = asyncio.Event + Semaphore: TypeAlias = asyncio.Semaphore generated_replacements: dict[type, str] = {} @@ -192,6 +193,7 @@ class _Sync_Impl: Future: TypeAlias = concurrent.futures.Future Task: TypeAlias = concurrent.futures.Future Event: TypeAlias = threading.Event + Semaphore: TypeAlias = threading.Semaphore generated_replacements: dict[type, str] = {} From 
426057fdb288397f6f870efc55d25ba432ccc661 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 17:31:33 -0700 Subject: [PATCH 096/360] fixed mypy errors --- .../cloud/bigtable/data/_async/_read_rows.py | 4 ++-- google/cloud/bigtable/data/_async/client.py | 4 ++-- .../bigtable/data/_async/mutations_batcher.py | 6 ++---- google/cloud/bigtable/data/_sync/_read_rows.py | 4 ++-- google/cloud/bigtable/data/_sync/client.py | 18 +++++++++++------- .../bigtable/data/_sync/mutations_batcher.py | 6 ++---- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 691df860a..4bd9eafcc 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -226,8 +226,8 @@ async def chunk_stream( @staticmethod async def merge_rows( - chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None - ) -> AsyncGenerator[Row, None]: + chunks: AsyncIterable[ReadRowsResponsePB.CellChunk] | None + ) -> AsyncIterable[Row]: """ Merge chunks into rows diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 39db62cc6..b09f607e4 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -119,8 +119,6 @@ "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", "BigtableAsyncClient": "BigtableClient", "AsyncPooledChannel": "PooledChannel", - "_ReadRowsOperationAsync": "_ReadRowsOperation", - "_MutateRowsOperationAsync": "_MutateRowsOperation", }, ) class BigtableDataClientAsync(ClientWithProject): @@ -495,6 +493,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): "AsyncIterable": "Iterable", "MutationsBatcherAsync": "MutationsBatcher", "BigtableDataClientAsync": "BigtableDataClient", + "_ReadRowsOperationAsync": "_ReadRowsOperation", + "_MutateRowsOperationAsync": "_MutateRowsOperation", }, ) class TableAsync: diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 16f348406..3cd17b458 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -281,8 +281,6 @@ async def _timer_routine(self, interval: float | None) -> None: Args: flush_interval: Automatically flush every flush_interval seconds. If None, no time-based flushing is performed. 
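On the generated sync side this timer routine reduces to a plain threading.Event wait loop (see the _sync/mutations_batcher.py hunk earlier in this series); a standalone sketch of that shape, with the staged-entry check simplified away:

    import threading

    def timer_routine(closed: threading.Event, interval, flush):
        if not interval or interval <= 0:
            return
        while not closed.is_set():
            # returns early if the batcher is closed before the interval elapses
            closed.wait(timeout=interval)
            if not closed.is_set():
                flush()   # the real routine only flushes when entries are staged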
- Returns: - asyncio.Future[None]: future representing the background task """ if not interval or interval <= 0: return None @@ -321,12 +319,12 @@ async def append(self, mutation_entry: RowMutationEntry): # yield to the event loop to allow flush to run await CrossSync.yield_to_event_loop() - def _schedule_flush(self) -> asyncio.Future[None] | None: + def _schedule_flush(self) -> CrossSync.Future[None] | None: """ Update the flush task to include the latest staged entries Returns: - asyncio.Future[None] | None: + Future[None] | None: future representing the background task, if started """ if self._staged_entries: diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index d43782a8f..9e702e1d6 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -183,8 +183,8 @@ def chunk_stream( @staticmethod def merge_rows( - chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None - ) -> AsyncGenerator[Row, None]: + chunks: Iterable[ReadRowsResponsePB.CellChunk] | None, + ) -> Iterable[Row]: """ Merge chunks into rows diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 741ea603d..bf2193fd4 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -38,12 +38,12 @@ from google.api_core.exceptions import ServiceUnavailable from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT from google.cloud.bigtable.data import _helpers -from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync -from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import _MB_SIZE +from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation +from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation from google.cloud.bigtable.data._sync.cross_sync import CrossSync from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher from google.cloud.bigtable.data.exceptions import FailedQueryShardError @@ -572,7 +572,7 @@ def read_rows_stream( operation_timeout, attempt_timeout, self ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - row_merger = _ReadRowsOperationAsync( + row_merger = _ReadRowsOperation( query, self, operation_timeout=operation_timeout, @@ -724,7 +724,7 @@ def read_rows_sharded( rpc_timeout_generator = _helpers._attempt_timeout_generator( operation_timeout, operation_timeout ) - concurrency_sem = asyncio.Semaphore(_helpers._CONCURRENCY_LIMIT) + concurrency_sem = CrossSync._Sync_Impl.Semaphore(_helpers._CONCURRENCY_LIMIT) def read_rows_with_semaphore(query): with concurrency_sem: @@ -740,8 +740,12 @@ def read_rows_with_semaphore(query): retryable_errors=retryable_errors, ) - routine_list = [read_rows_with_semaphore(query) for query in sharded_query] - batch_result = asyncio.gather(*routine_list, return_exceptions=True) + routine_list = [ + partial(read_rows_with_semaphore, query) for query in sharded_query + ] + batch_result = CrossSync._Sync_Impl.gather_partials( + routine_list, return_exceptions=True, sync_executor=self.client._executor + ) error_dict = {} shard_idx = 0 results_list = [] @@ -1043,7 +1047,7 @@ def 
bulk_mutate_rows( operation_timeout, attempt_timeout, self ) retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - operation = _MutateRowsOperationAsync( + operation = _MutateRowsOperation( self.client._gapic_client, self, mutation_entries, diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 8a38af39c..f3644e62b 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -127,8 +127,6 @@ def _timer_routine(self, interval: float | None) -> None: Args: flush_interval: Automatically flush every flush_interval seconds. If None, no time-based flushing is performed. - Returns: - asyncio.Future[None]: future representing the background task """ if not interval or interval <= 0: return None @@ -165,12 +163,12 @@ def append(self, mutation_entry: RowMutationEntry): self._schedule_flush() CrossSync._Sync_Impl.yield_to_event_loop() - def _schedule_flush(self) -> asyncio.Future[None] | None: + def _schedule_flush(self) -> CrossSync._Sync_Impl.Future[None] | None: """ Update the flush task to include the latest staged entries Returns: - asyncio.Future[None] | None: + Future[None] | None: future representing the background task, if started """ if self._staged_entries: From 0c79e39074d8010aa999ecb127e075c75a15724f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 21 Jun 2024 17:32:24 -0700 Subject: [PATCH 097/360] ran blacken --- google/cloud/bigtable/data/_async/_read_rows.py | 2 +- google/cloud/bigtable/data/_async/client.py | 8 ++++---- google/cloud/bigtable/data/_async/mutations_batcher.py | 4 +++- google/cloud/bigtable/data/_sync/cross_sync.py | 2 +- tests/unit/data/_async/test_mutations_batcher.py | 1 + 5 files changed, 10 insertions(+), 7 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 4bd9eafcc..d46b9ec6a 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -226,7 +226,7 @@ async def chunk_stream( @staticmethod async def merge_rows( - chunks: AsyncIterable[ReadRowsResponsePB.CellChunk] | None + chunks: AsyncIterable[ReadRowsResponsePB.CellChunk] | None, ) -> AsyncIterable[Row]: """ Merge chunks into rows diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index b09f607e4..00a3ee419 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -838,11 +838,11 @@ async def read_rows_with_semaphore(query): retryable_errors=retryable_errors, ) - routine_list = [partial(read_rows_with_semaphore, query) for query in sharded_query] + routine_list = [ + partial(read_rows_with_semaphore, query) for query in sharded_query + ] batch_result = await CrossSync.gather_partials( - routine_list, - return_exceptions=True, - sync_executor=self.client._executor + routine_list, return_exceptions=True, sync_executor=self.client._executor ) # collect results and errors diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 3cd17b458..643e90126 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -286,7 +286,9 @@ async def _timer_routine(self, interval: float | None) -> None: return None while not self._closed.is_set(): # wait until interval has passed, or 
until closed - await CrossSync.event_wait(self._closed, timeout=interval, async_break_early=False) + await CrossSync.event_wait( + self._closed, timeout=interval, async_break_early=False + ) if not self._closed.is_set() and self._staged_entries: self._schedule_flush() diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 333730114..3b5f7ef1a 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -57,7 +57,7 @@ def decorator(func): def sync_output( cls, sync_path: str, - replace_symbols: dict["str", "str" | None ] | None = None, + replace_symbols: dict["str", "str" | None] | None = None, mypy_ignore: list[str] | None = None, ): replace_symbols = replace_symbols or {} diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index b76eee300..83dec7097 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -553,6 +553,7 @@ async def test__start_flush_timer_call_when_closed( async def test__flush_timer(self, num_staged): """Timer should continue to call _schedule_flush in a loop""" from google.cloud.bigtable.data._sync.cross_sync import CrossSync + with mock.patch.object( self._get_target_class(), "_schedule_flush" ) as flush_mock: From 4aa53eb3fbb7143d144e7209590af82e041f4e89 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 24 Jun 2024 14:01:54 -0700 Subject: [PATCH 098/360] generate sync unit tests --- sync_surface_generator.py | 7 +- tests/unit/data/_async/test__mutate_rows.py | 27 +- tests/unit/data/_async/test__read_rows.py | 5 + tests/unit/data/_async/test_client.py | 31 +- .../data/_async/test_mutations_batcher.py | 9 +- .../data/_async/test_read_rows_acceptance.py | 5 + tests/unit/data/_sync/test__mutate_rows.py | 324 ++ tests/unit/data/_sync/test__read_rows.py | 363 +++ tests/unit/data/_sync/test_client.py | 2779 +++++++++++++++++ .../unit/data/_sync/test_mutations_batcher.py | 1125 +++++++ .../data/_sync/test_read_rows_acceptance.py | 328 ++ 11 files changed, 4991 insertions(+), 12 deletions(-) create mode 100644 tests/unit/data/_sync/test__mutate_rows.py create mode 100644 tests/unit/data/_sync/test__read_rows.py create mode 100644 tests/unit/data/_sync/test_client.py create mode 100644 tests/unit/data/_sync/test_mutations_batcher.py create mode 100644 tests/unit/data/_sync/test_read_rows_acceptance.py diff --git a/sync_surface_generator.py b/sync_surface_generator.py index efd964ab4..25ba85a0d 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -344,8 +344,13 @@ def transform_class(in_obj: Type, **kwargs): # find all classes in the library lib_root = "google/cloud/bigtable/data/_async" lib_files = [f"{lib_root}/{f}" for f in os.listdir(lib_root) if f.endswith(".py")] + + test_root = "tests/unit/data/_async" + test_files = [f"{test_root}/{f}" for f in os.listdir(test_root) if f.endswith(".py")] + all_files = lib_files + test_files + enabled_classes = [] - for file in lib_files: + for file in all_files: file_module = file.replace("/", ".")[:-3] for cls_name, cls in inspect.getmembers(importlib.import_module(file_module), inspect.isclass): # keep only those with CrossSync annotation diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 26a9325f0..03b7db3f4 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ 
-16,7 +16,10 @@ from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 -import google.api_core.exceptions as core_exceptions +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import Forbidden + +from google.cloud.bigtable.data._sync.cross_sync import CrossSync # try/except added for compatibility with python < 3.8 try: @@ -27,13 +30,21 @@ from mock import AsyncMock # type: ignore +@CrossSync.sync_output( + "tests.unit.data._sync.test__mutate_rows.TestMutateRowsOperation", +) class TestMutateRowsOperation: def _target_class(self): - from google.cloud.bigtable.data._async._mutate_rows import ( - _MutateRowsOperationAsync, - ) + if CrossSync.is_async: + from google.cloud.bigtable.data._async._mutate_rows import ( + _MutateRowsOperationAsync, + ) + + return _MutateRowsOperationAsync + else: + from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation - return _MutateRowsOperationAsync + return _MutateRowsOperation def _make_one(self, *args, **kwargs): if not args: @@ -171,7 +182,7 @@ async def test_mutate_rows_operation(self): assert attempt_mock.call_count == 1 @pytest.mark.parametrize( - "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] + "exc_type", [RuntimeError, ZeroDivisionError, Forbidden] ) @pytest.mark.asyncio async def test_mutate_rows_attempt_exception(self, exc_type): @@ -199,7 +210,7 @@ async def test_mutate_rows_attempt_exception(self, exc_type): assert len(instance.remaining_indices) == 0 @pytest.mark.parametrize( - "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] + "exc_type", [RuntimeError, ZeroDivisionError, Forbidden] ) @pytest.mark.asyncio async def test_mutate_rows_exception(self, exc_type): @@ -237,7 +248,7 @@ async def test_mutate_rows_exception(self, exc_type): @pytest.mark.parametrize( "exc_type", - [core_exceptions.DeadlineExceeded, RuntimeError], + [DeadlineExceeded, RuntimeError], ) @pytest.mark.asyncio async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index d9d12a729..abda3af05 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -13,6 +13,8 @@ import pytest +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + # try/except added for compatibility with python < 3.8 try: from unittest import mock @@ -27,6 +29,9 @@ TEST_LABELS = ["label1", "label2"] +@CrossSync.sync_output( + "tests.unit.data._sync.test__read_rows.TestReadRowsOperation", +) class TestReadRowsOperation: """ Tests helper functions in the ReadRowsOperation class diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index ccd71046c..8af98c376 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -42,6 +42,8 @@ from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + # try/except added for compatibility with python < 3.8 try: from unittest import mock @@ -51,6 +53,9 @@ from mock import AsyncMock # type: ignore +@CrossSync.sync_output( + "tests.unit.data._sync.test_client.TestBigtableDataClient", +) class TestBigtableDataClientAsync: @staticmethod def _get_target_class(): @@ -1090,7 +1095,9 @@ def 
test_client_ctor_sync(self): assert client.project == "project-id" assert client._channel_refresh_tasks == [] - +@CrossSync.sync_output( + "tests.unit.data._sync.test_client.TestTable", +) class TestTableAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -1409,7 +1416,9 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ else: assert "app_profile_id=" not in goog_metadata - +@CrossSync.sync_output( + "tests.unit.data._sync.test_client.TestReadRows", +) class TestReadRowsAsync: """ Tests for table.read_rows and related methods. @@ -1916,6 +1925,9 @@ async def test_row_exists(self, return_value, expected_result): assert query.filter._to_dict() == expected_filter +@CrossSync.sync_output( + "tests.unit.data._sync.test_client.TestReadRowsSharded", +) class TestReadRowsShardedAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2136,6 +2148,9 @@ async def mock_call(*args, **kwargs): ) +@CrossSync.sync_output( + "tests.unit.data._sync.test_client.TestSampleRowKeys", +) class TestSampleRowKeysAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2287,6 +2302,9 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio await table.sample_row_keys() +@CrossSync.sync_output( + "tests.unit.data._sync.test_client.TestMutateRow", +) class TestMutateRowAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2462,6 +2480,9 @@ async def test_mutate_row_no_mutations(self, mutations): assert e.value.args[0] == "No mutations provided" +@CrossSync.sync_output( + "tests.unit.data._sync.test_client.TestBulkMutateRows", +) class TestBulkMutateRowsAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2841,6 +2862,9 @@ async def test_bulk_mutate_error_recovery(self): await table.bulk_mutate_rows(entries, operation_timeout=1000) +@CrossSync.sync_output( + "tests.unit.data._sync.test_client.TestCheckAndMutateRow", +) class TestCheckAndMutateRowAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2992,6 +3016,9 @@ async def test_check_and_mutate_mutations_parsing(self): ) +@CrossSync.sync_output( + "tests.unit.data._sync.test_client.TestReadModifyWriteRow", +) class TestReadModifyWriteRowAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 83dec7097..d7f5b68cf 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -21,6 +21,8 @@ from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data import TableAsync +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + # try/except added for compatibility with python < 3.8 try: from unittest import mock @@ -30,6 +32,9 @@ from mock import AsyncMock # type: ignore +@CrossSync.sync_output( + "tests.unit.data._sync.test_mutations_batcher.Test_FlowControl" +) class Test_FlowControl: @staticmethod def _target_class(): @@ -306,7 +311,9 @@ async def test_add_to_flow_oversize(self): ] assert len(count_results) == 1 - +@CrossSync.sync_output( + 
"tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" +) class TestMutationsBatcherAsync: def _get_target_class(self): from google.cloud.bigtable.data._async.mutations_batcher import ( diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 600c10b3b..7434e20af 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -27,7 +27,12 @@ from tests.unit.v2_client.test_row_merger import ReadRowsTest, TestFile +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +@CrossSync.sync_output( + "tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", +) class TestReadRowsAcceptanceAsync: @staticmethod def _get_operation_class(): diff --git a/tests/unit/data/_sync/test__mutate_rows.py b/tests/unit/data/_sync/test__mutate_rows.py new file mode 100644 index 000000000..bbe6fbb22 --- /dev/null +++ b/tests/unit/data/_sync/test__mutate_rows.py @@ -0,0 +1,324 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by sync_surface_generator.py. Do not edit. + + +from mock import AsyncMock +# from tests.unit.data._async.test__mutate_rows import TestMutateRowsOperation +from unittest import mock +from unittest.mock import AsyncMock +import mock +import pytest + +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import Forbidden +from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable_v2.types import MutateRowsResponse +from google.rpc import status_pb2 + + +class TestMutateRowsOperation: + def _target_class(self): + if CrossSync._Sync_Impl.is_async: + from google.cloud.bigtable.data._async._mutate_rows import ( + _MutateRowsOperationAsync, + ) + + return _MutateRowsOperationAsync + else: + from google.cloud.bigtable.data._sync._mutate_rows import ( + _MutateRowsOperation, + ) + + return _MutateRowsOperation + + def _make_one(self, *args, **kwargs): + if not args: + kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) + kwargs["table"] = kwargs.pop("table", AsyncMock()) + kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) + kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) + kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) + kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) + return self._target_class()(*args, **kwargs) + + def _make_mutation(self, count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def _mock_stream(self, mutation_list, error_dict): + for idx, entry in enumerate(mutation_list): + code = error_dict.get(idx, 0) + yield MutateRowsResponse( + entries=[ + MutateRowsResponse.Entry( + index=idx, status=status_pb2.Status(code=code) + ) + ] + ) + + def _make_mock_gapic(self, mutation_list, error_dict=None): + mock_fn = 
AsyncMock() + if error_dict is None: + error_dict = {} + mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( + mutation_list, error_dict + ) + return mock_fn + + def test_ctor(self): + """test that constructor sets all the attributes correctly""" + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + from google.api_core.exceptions import DeadlineExceeded + from google.api_core.exceptions import Aborted + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + attempt_timeout = 0.01 + retryable_exceptions = () + instance = self._make_one( + client, + table, + entries, + operation_timeout, + attempt_timeout, + retryable_exceptions, + ) + assert client.mutate_rows.call_count == 0 + instance._gapic_fn() + assert client.mutate_rows.call_count == 1 + inner_kwargs = client.mutate_rows.call_args[1] + assert len(inner_kwargs) == 4 + assert inner_kwargs["table_name"] == table.table_name + assert inner_kwargs["app_profile_id"] == table.app_profile_id + assert inner_kwargs["retry"] is None + metadata = inner_kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert str(table.table_name) in metadata[0][1] + assert str(table.app_profile_id) in metadata[0][1] + entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries] + assert instance.mutations == entries_w_pb + assert next(instance.timeout_generator) == attempt_timeout + assert instance.is_retryable is not None + assert instance.is_retryable(DeadlineExceeded("")) is False + assert instance.is_retryable(Aborted("")) is False + assert instance.is_retryable(_MutateRowsIncomplete("")) is True + assert instance.is_retryable(RuntimeError("")) is False + assert instance.remaining_indices == list(range(len(entries))) + assert instance.errors == {} + + def test_ctor_too_many_entries(self): + """should raise an error if an operation is created with more than 100,000 entries""" + from google.cloud.bigtable.data._async._mutate_rows import ( + _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, + ) + + assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100000 + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] * (_MUTATE_ROWS_REQUEST_MUTATION_LIMIT + 1) + operation_timeout = 0.05 + attempt_timeout = 0.01 + with pytest.raises(ValueError) as e: + self._make_one(client, table, entries, operation_timeout, attempt_timeout) + assert "mutate_rows requests can contain at most 100000 mutations" in str( + e.value + ) + assert "Found 100001" in str(e.value) + + def test_mutate_rows_operation(self): + """Test successful case of mutate_rows_operation""" + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + cls = self._target_class() + with mock.patch( + f"{cls.__module__}.{cls.__name__}._run_attempt", AsyncMock() + ) as attempt_mock: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + assert attempt_mock.call_count == 1 + + @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) + def test_mutate_rows_attempt_exception(self, exc_type): + """exceptions raised from attempt should be raised in MutationsExceptionGroup""" + client = AsyncMock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + 
expected_exception = exc_type("test") + client.mutate_rows.side_effect = expected_exception + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance._run_attempt() + except Exception as e: + found_exc = e + assert client.mutate_rows.call_count == 1 + assert type(found_exc) is exc_type + assert found_exc == expected_exception + assert len(instance.errors) == 2 + assert len(instance.remaining_indices) == 0 + + @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) + def test_mutate_rows_exception(self, exc_type): + """exceptions raised from retryable should be raised in MutationsExceptionGroup""" + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + expected_cause = exc_type("abort") + with mock.patch.object( + self._target_class(), "_run_attempt", AsyncMock() + ) as attempt_mock: + attempt_mock.side_effect = expected_cause + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + except MutationsExceptionGroup as e: + found_exc = e + assert attempt_mock.call_count == 1 + assert len(found_exc.exceptions) == 2 + assert isinstance(found_exc.exceptions[0], FailedMutationEntryError) + assert isinstance(found_exc.exceptions[1], FailedMutationEntryError) + assert found_exc.exceptions[0].__cause__ == expected_cause + assert found_exc.exceptions[1].__cause__ == expected_cause + + @pytest.mark.parametrize("exc_type", [DeadlineExceeded, RuntimeError]) + def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): + """If an exception fails but eventually passes, it should not raise an exception""" + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] + operation_timeout = 1 + expected_cause = exc_type("retry") + num_retries = 2 + with mock.patch.object( + self._target_class(), "_run_attempt", AsyncMock() + ) as attempt_mock: + attempt_mock.side_effect = [expected_cause] * num_retries + [None] + instance = self._make_one( + client, + table, + entries, + operation_timeout, + operation_timeout, + retryable_exceptions=(exc_type,), + ) + instance.start() + assert attempt_mock.call_count == num_retries + 1 + + def test_mutate_rows_incomplete_ignored(self): + """MutateRowsIncomplete exceptions should not be added to error list""" + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.api_core.exceptions import DeadlineExceeded + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] + operation_timeout = 0.05 + with mock.patch.object( + self._target_class(), "_run_attempt", AsyncMock() + ) as attempt_mock: + attempt_mock.side_effect = _MutateRowsIncomplete("ignored") + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + except MutationsExceptionGroup as e: + found_exc = e + assert attempt_mock.call_count > 0 + assert len(found_exc.exceptions) == 1 + assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) + + def test_run_attempt_single_entry_success(self): + """Test mutating a single entry""" + mutation = 
self._make_mutation() + expected_timeout = 1.3 + mock_gapic_fn = self._make_mock_gapic({0: mutation}) + instance = self._make_one( + mutation_entries=[mutation], attempt_timeout=expected_timeout + ) + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + instance._run_attempt() + assert len(instance.remaining_indices) == 0 + assert mock_gapic_fn.call_count == 1 + _, kwargs = mock_gapic_fn.call_args + assert kwargs["timeout"] == expected_timeout + assert kwargs["entries"] == [mutation._to_pb()] + + def test_run_attempt_empty_request(self): + """Calling with no mutations should result in no API calls""" + mock_gapic_fn = self._make_mock_gapic([]) + instance = self._make_one(mutation_entries=[]) + instance._run_attempt() + assert mock_gapic_fn.call_count == 0 + + def test_run_attempt_partial_success_retryable(self): + """Some entries succeed, but one fails. Should report the proper index, and raise incomplete exception""" + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() + mutations = [success_mutation, failure_mutation, success_mutation_2] + mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) + instance = self._make_one(mutation_entries=mutations) + instance.is_retryable = lambda x: True + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + with pytest.raises(_MutateRowsIncomplete): + instance._run_attempt() + assert instance.remaining_indices == [1] + assert 0 not in instance.errors + assert len(instance.errors[1]) == 1 + assert instance.errors[1][0].grpc_status_code == 300 + assert 2 not in instance.errors + + def test_run_attempt_partial_success_non_retryable(self): + """Some entries succeed, but one fails. Exception marked as non-retryable. Do not raise incomplete error""" + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() + mutations = [success_mutation, failure_mutation, success_mutation_2] + mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) + instance = self._make_one(mutation_entries=mutations) + instance.is_retryable = lambda x: False + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + instance._run_attempt() + assert instance.remaining_indices == [] + assert 0 not in instance.errors + assert len(instance.errors[1]) == 1 + assert instance.errors[1][0].grpc_status_code == 300 + assert 2 not in instance.errors diff --git a/tests/unit/data/_sync/test__read_rows.py b/tests/unit/data/_sync/test__read_rows.py new file mode 100644 index 000000000..0cda06432 --- /dev/null +++ b/tests/unit/data/_sync/test__read_rows.py @@ -0,0 +1,363 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by sync_surface_generator.py. Do not edit. 
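
The generated files under tests/unit/data/_sync/ (this one and the ones that follow) are emitted for the async test classes that carry the @CrossSync.sync_output annotation added earlier in this commit. Only the decorator's signature is visible in this series (in the cross_sync.py hunk), so the snippet below is a minimal sketch of how such a registration hook could work; the class attributes and the registry are hypothetical names used purely for illustration, not the library's actual internals.

    # Minimal sketch of a registration-style sync_output decorator.
    # Only the signature matches this patch series; everything inside
    # the decorator body is an assumption for illustration.
    class CrossSync:
        is_async = True

        _registered_outputs = {}

        @classmethod
        def sync_output(cls, sync_path, replace_symbols=None, mypy_ignore=None):
            replace_symbols = replace_symbols or {}
            mypy_ignore = mypy_ignore or []

            def decorator(async_cls):
                # Mark the class so the generator's "keep only those with
                # CrossSync annotation" filter picks it up, and remember where
                # the generated sync twin should be written.
                async_cls.cross_sync_enabled = True        # hypothetical attribute
                async_cls.cross_sync_path = sync_path      # hypothetical attribute
                async_cls.cross_sync_replace_symbols = replace_symbols
                async_cls.cross_sync_mypy_ignore = mypy_ignore
                cls._registered_outputs[sync_path] = async_cls
                return async_cls

            return decorator

Under that reading, the generator presumably walks the annotated classes it finds in tests/unit/data/_async, applies the symbol replacements, and writes the result to the sync_path given in each annotation, which is why every generated file opens with the "Do not edit" banner.
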
+ + +from tests.unit.data._async.test__read_rows import TestReadRowsOperation +from unittest import mock +import mock +import pytest + + +class TestReadRowsOperation: + """ + Tests helper functions in the ReadRowsOperation class + in-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt + is tested in test_read_rows_acceptance test_client_read_rows, and conformance tests + """ + + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + + return _ReadRowsOperationAsync + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_ctor(self): + from google.cloud.bigtable.data import ReadRowsQuery + + row_limit = 91 + query = ReadRowsQuery(limit=row_limit) + client = mock.Mock() + client.read_rows = mock.Mock() + client.read_rows.return_value = None + table = mock.Mock() + table._client = client + table.table_name = "test_table" + table.app_profile_id = "test_profile" + expected_operation_timeout = 42 + expected_request_timeout = 44 + time_gen_mock = mock.Mock() + with mock.patch( + "google.cloud.bigtable.data._helpers._attempt_timeout_generator", + time_gen_mock, + ): + instance = self._make_one( + query, + table, + operation_timeout=expected_operation_timeout, + attempt_timeout=expected_request_timeout, + ) + assert time_gen_mock.call_count == 1 + time_gen_mock.assert_called_once_with( + expected_request_timeout, expected_operation_timeout + ) + assert instance._last_yielded_row_key is None + assert instance._remaining_count == row_limit + assert instance.operation_timeout == expected_operation_timeout + assert client.read_rows.call_count == 0 + assert instance._metadata == [ + ( + "x-goog-request-params", + "table_name=test_table&app_profile_id=test_profile", + ) + ] + assert instance.request.table_name == table.table_name + assert instance.request.app_profile_id == table.app_profile_id + assert instance.request.rows_limit == row_limit + + @pytest.mark.parametrize( + "in_keys,last_key,expected", + [ + (["b", "c", "d"], "a", ["b", "c", "d"]), + (["a", "b", "c"], "b", ["c"]), + (["a", "b", "c"], "c", []), + (["a", "b", "c"], "d", []), + (["d", "c", "b", "a"], "b", ["d", "c"]), + ], + ) + @pytest.mark.parametrize("with_range", [True, False]) + def test_revise_request_rowset_keys_with_range( + self, in_keys, last_key, expected, with_range + ): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + from google.cloud.bigtable.data.exceptions import _RowSetComplete + + in_keys = [key.encode("utf-8") for key in in_keys] + expected = [key.encode("utf-8") for key in expected] + last_key = last_key.encode("utf-8") + if with_range: + sample_range = [RowRangePB(start_key_open=last_key)] + else: + sample_range = [] + row_set = RowSetPB(row_keys=in_keys, row_ranges=sample_range) + if not with_range and expected == []: + with pytest.raises(_RowSetComplete): + self._get_target_class()._revise_request_rowset(row_set, last_key) + else: + revised = self._get_target_class()._revise_request_rowset(row_set, last_key) + assert revised.row_keys == expected + assert revised.row_ranges == sample_range + + @pytest.mark.parametrize( + "in_ranges,last_key,expected", + [ + ( + [{"start_key_open": "b", "end_key_closed": "d"}], + "a", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_closed": "b", "end_key_closed": "d"}], + "a", + [{"start_key_closed": "b", "end_key_closed": 
"d"}], + ), + ( + [{"start_key_open": "a", "end_key_closed": "d"}], + "b", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_closed": "a", "end_key_open": "d"}], + "b", + [{"start_key_open": "b", "end_key_open": "d"}], + ), + ( + [{"start_key_closed": "b", "end_key_closed": "d"}], + "b", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ([{"start_key_closed": "b", "end_key_closed": "d"}], "d", []), + ([{"start_key_closed": "b", "end_key_open": "d"}], "d", []), + ([{"start_key_closed": "b", "end_key_closed": "d"}], "e", []), + ([{"start_key_closed": "b"}], "z", [{"start_key_open": "z"}]), + ([{"start_key_closed": "b"}], "a", [{"start_key_closed": "b"}]), + ( + [{"end_key_closed": "z"}], + "a", + [{"start_key_open": "a", "end_key_closed": "z"}], + ), + ( + [{"end_key_open": "z"}], + "a", + [{"start_key_open": "a", "end_key_open": "z"}], + ), + ], + ) + @pytest.mark.parametrize("with_key", [True, False]) + def test_revise_request_rowset_ranges( + self, in_ranges, last_key, expected, with_key + ): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + from google.cloud.bigtable.data.exceptions import _RowSetComplete + + next_key = (last_key + "a").encode("utf-8") + last_key = last_key.encode("utf-8") + in_ranges = [ + RowRangePB(**{k: v.encode("utf-8") for k, v in r.items()}) + for r in in_ranges + ] + expected = [ + RowRangePB(**{k: v.encode("utf-8") for k, v in r.items()}) for r in expected + ] + if with_key: + row_keys = [next_key] + else: + row_keys = [] + row_set = RowSetPB(row_ranges=in_ranges, row_keys=row_keys) + if not with_key and expected == []: + with pytest.raises(_RowSetComplete): + self._get_target_class()._revise_request_rowset(row_set, last_key) + else: + revised = self._get_target_class()._revise_request_rowset(row_set, last_key) + assert revised.row_keys == row_keys + assert revised.row_ranges == expected + + @pytest.mark.parametrize("last_key", ["a", "b", "c"]) + def test_revise_request_full_table(self, last_key): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + last_key = last_key.encode("utf-8") + row_set = RowSetPB() + for selected_set in [row_set, None]: + revised = self._get_target_class()._revise_request_rowset( + selected_set, last_key + ) + assert revised.row_keys == [] + assert len(revised.row_ranges) == 1 + assert revised.row_ranges[0] == RowRangePB(start_key_open=last_key) + + def test_revise_to_empty_rowset(self): + """revising to an empty rowset should raise error""" + from google.cloud.bigtable.data.exceptions import _RowSetComplete + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + row_keys = [b"a", b"b", b"c"] + row_range = RowRangePB(end_key_open=b"c") + row_set = RowSetPB(row_keys=row_keys, row_ranges=[row_range]) + with pytest.raises(_RowSetComplete): + self._get_target_class()._revise_request_rowset(row_set, b"d") + + @pytest.mark.parametrize( + "start_limit,emit_num,expected_limit", + [ + (10, 0, 10), + (10, 1, 9), + (10, 10, 0), + (None, 10, None), + (None, 0, None), + (4, 2, 2), + ], + ) + def test_revise_limit(self, start_limit, emit_num, expected_limit): + """ + revise_limit should revise the request's limit field + - if limit is 0 (unlimited), it should never be revised + - if start_limit-emit_num == 0, the request should end early + - if the number emitted exceeds 
the new limit, an exception should + should be raised (tested in test_revise_limit_over_limit) + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable_v2.types import ReadRowsResponse + + def awaitable_stream(): + def mock_stream(): + for i in range(emit_num): + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + ] + ) + + return mock_stream() + + query = ReadRowsQuery(limit=start_limit) + table = mock.Mock() + table.table_name = "table_name" + table.app_profile_id = "app_profile_id" + instance = self._make_one(query, table, 10, 10) + assert instance._remaining_count == start_limit + for val in instance.chunk_stream(awaitable_stream()): + pass + assert instance._remaining_count == expected_limit + + @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) + def test_revise_limit_over_limit(self, start_limit, emit_num): + """ + Should raise runtime error if we get in state where emit_num > start_num + (unless start_num == 0, which represents unlimited) + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable_v2.types import ReadRowsResponse + from google.cloud.bigtable.data.exceptions import InvalidChunk + + def awaitable_stream(): + def mock_stream(): + for i in range(emit_num): + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + ] + ) + + return mock_stream() + + query = ReadRowsQuery(limit=start_limit) + table = mock.Mock() + table.table_name = "table_name" + table.app_profile_id = "app_profile_id" + instance = self._make_one(query, table, 10, 10) + assert instance._remaining_count == start_limit + with pytest.raises(InvalidChunk) as e: + for val in instance.chunk_stream(awaitable_stream()): + pass + assert "emit count exceeds row limit" in str(e.value) + + def test_aclose(self): + """ + should be able to close a stream safely with aclose. 
+ Closed generators should raise StopAsyncIteration on next yield + """ + + def mock_stream(): + while True: + yield 1 + + with mock.patch.object( + self._get_target_class(), "_read_rows_attempt" + ) as mock_attempt: + instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) + wrapped_gen = mock_stream() + mock_attempt.return_value = wrapped_gen + gen = instance.start_operation() + gen.__anext__() + gen.aclose() + with pytest.raises(StopAsyncIteration): + gen.__anext__() + gen.aclose() + with pytest.raises(StopAsyncIteration): + wrapped_gen.__anext__() + + def test_retryable_ignore_repeated_rows(self): + """Duplicate rows should cause an invalid chunk error""" + from google.cloud.bigtable.data.exceptions import InvalidChunk + from google.cloud.bigtable_v2.types import ReadRowsResponse + + row_key = b"duplicate" + + def mock_awaitable_stream(): + def mock_stream(): + while True: + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) + ] + ) + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) + ] + ) + + return mock_stream() + + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + stream = self._get_target_class().chunk_stream( + instance, mock_awaitable_stream() + ) + stream.__anext__() + with pytest.raises(InvalidChunk) as exc: + stream.__anext__() + assert "row keys should be strictly increasing" in str(exc.value) diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py new file mode 100644 index 000000000..4feafd373 --- /dev/null +++ b/tests/unit/data/_sync/test_client.py @@ -0,0 +1,2779 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by sync_surface_generator.py. Do not edit. 
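
Several hunks in this series swap asyncio.gather(*routines) for CrossSync.gather_partials(partials, return_exceptions=True, sync_executor=...) (see read_rows_sharded and the _ping_and_warm_instances tests). The synchronous implementation of gather_partials is not included in these patches, so the following is only a rough sketch, assuming the sync variant fans the zero-argument partials out over the client's ThreadPoolExecutor:

    # Rough sketch (an assumption, not the library's code) of how a
    # synchronous gather_partials could mirror asyncio.gather over a
    # thread pool while preserving input order in the results.
    def gather_partials(partial_list, return_exceptions=False, sync_executor=None):
        if sync_executor is None:
            raise ValueError("a ThreadPoolExecutor is required for the sync path")
        futures = [sync_executor.submit(partial) for partial in partial_list]
        results = []
        for future in futures:
            try:
                results.append(future.result())
            except Exception as exc:
                # with return_exceptions=True, exceptions are collected
                # in place of results, matching asyncio.gather semantics
                if return_exceptions:
                    results.append(exc)
                else:
                    raise
        return results

In read_rows_sharded above, the call sites build partial(read_rows_with_semaphore, query) callables and pass self.client._executor as sync_executor, so sharded reads stay concurrent in the generated sync surface.
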
+ + +from __future__ import annotations +from mock import AsyncMock +from tests.unit.data._async.test_client import TestBigtableDataClientAsync +from tests.unit.data._async.test_client import TestReadRowsAsync +from tests.unit.data._async.test_client import TestTableAsync +from unittest import mock +from unittest.mock import AsyncMock +import asyncio +import grpc +import mock +import pytest +import re +import sys + +from google.api_core import exceptions as core_exceptions +from google.api_core import grpc_helpers_async +from google.auth.credentials import AnonymousCredentials +from google.cloud.bigtable.data import TABLE_DEFAULT +from google.cloud.bigtable.data import mutations +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule +from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable_v2 import ReadRowsResponse +from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, +) +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel as PooledChannelAsync, +) +from google.cloud.bigtable_v2.types import ReadRowsResponse + + +class TestBigtableDataClient: + @staticmethod + def _get_target_class(): + if CrossSync.is_async: + from google.cloud.bigtable.data._async.client import BigtableDataClientAsync + return BigtableDataClientAsync + else: + from google.cloud.bigtable.data._sync.client import BigtableDataClient + return BigtableDataClient + + @classmethod + def _make_client(cls, *args, use_emulator=True, **kwargs): + import os + + env_mask = {} + if use_emulator: + env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" + import warnings + + warnings.filterwarnings("ignore", category=RuntimeWarning) + else: + kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) + kwargs["project"] = kwargs.get("project", "project-id") + with mock.patch.dict(os.environ, env_mask): + return cls._get_target_class()(*args, **kwargs) + + @property + def is_async(self): + return True + + def test_ctor(self): + expected_project = "project-id" + expected_pool_size = 11 + expected_credentials = AnonymousCredentials() + client = self._make_client( + project="project-id", + pool_size=expected_pool_size, + credentials=expected_credentials, + use_emulator=False, + ) + asyncio.sleep(0) + assert client.project == expected_project + assert len(client.transport._grpc_channel._pool) == expected_pool_size + assert not client._active_instances + assert len(client._channel_refresh_tasks) == expected_pool_size + assert client.transport._credentials == expected_credentials + client.close() + + def test_ctor_super_inits(self): + from google.cloud.client import ClientWithProject + from google.api_core import client_options as client_options_lib + from google.cloud.bigtable import __version__ as bigtable_version + + project = "project-id" + pool_size = 11 + credentials = AnonymousCredentials() + client_options = {"api_endpoint": "foo.bar:1234"} + options_parsed = client_options_lib.from_dict(client_options) + asyncio_portion = "-async" if self.is_async else "" + transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" + with 
mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: + bigtable_client_init.return_value = None + with mock.patch.object( + ClientWithProject, "__init__" + ) as client_project_init: + client_project_init.return_value = None + try: + self._make_client( + project=project, + pool_size=pool_size, + credentials=credentials, + client_options=options_parsed, + use_emulator=False, + ) + except AttributeError: + pass + assert bigtable_client_init.call_count == 1 + kwargs = bigtable_client_init.call_args[1] + assert kwargs["transport"] == transport_str + assert kwargs["credentials"] == credentials + assert kwargs["client_options"] == options_parsed + assert client_project_init.call_count == 1 + kwargs = client_project_init.call_args[1] + assert kwargs["project"] == project + assert kwargs["credentials"] == credentials + assert kwargs["client_options"] == options_parsed + + def test_ctor_dict_options(self): + from google.api_core.client_options import ClientOptions + + client_options = {"api_endpoint": "foo.bar:1234"} + with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: + try: + self._make_client(client_options=client_options) + except TypeError: + pass + bigtable_client_init.assert_called_once() + kwargs = bigtable_client_init.call_args[1] + called_options = kwargs["client_options"] + assert called_options.api_endpoint == "foo.bar:1234" + assert isinstance(called_options, ClientOptions) + with mock.patch.object( + self._get_target_class(), "_start_background_channel_refresh" + ) as start_background_refresh: + client = self._make_client( + client_options=client_options, use_emulator=False + ) + start_background_refresh.assert_called_once() + client.close() + + def test_veneer_grpc_headers(self): + client_component = "data-async" if self.is_async else "data" + VENEER_HEADER_REGEX = re.compile( + "gapic\\/[0-9]+\\.[\\w.-]+ gax\\/[0-9]+\\.[\\w.-]+ gccl\\/[0-9]+\\.[\\w.-]+-" + + client_component + + " gl-python\\/[0-9]+\\.[\\w.-]+ grpc\\/[0-9]+\\.[\\w.-]+" + ) + if self.is_async: + patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") + else: + patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") + with patch as gapic_mock: + client = self._make_client(project="project-id") + wrapped_call_list = gapic_mock.call_args_list + assert len(wrapped_call_list) > 0 + for call in wrapped_call_list: + client_info = call.kwargs["client_info"] + assert client_info is not None, f"{call} has no client_info" + wrapped_user_agent_sorted = " ".join( + sorted(client_info.to_user_agent().split(" ")) + ) + assert VENEER_HEADER_REGEX.match( + wrapped_user_agent_sorted + ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" + client.close() + + def test_channel_pool_creation(self): + pool_size = 14 + with mock.patch.object( + grpc_helpers_async, "create_channel", AsyncMock() + ) as create_channel: + client = self._make_client(project="project-id", pool_size=pool_size) + assert create_channel.call_count == pool_size + client.close() + client = self._make_client(project="project-id", pool_size=pool_size) + pool_list = list(client.transport._grpc_channel._pool) + pool_set = set(client.transport._grpc_channel._pool) + assert len(pool_list) == len(pool_set) + client.close() + + def test_channel_pool_rotation(self): + pool_size = 7 + with mock.patch.object(PooledChannelAsync, "next_channel") as next_channel: + client = self._make_client(project="project-id", pool_size=pool_size) + assert len(client.transport._grpc_channel._pool) == 
pool_size + next_channel.reset_mock() + with mock.patch.object( + type(client.transport._grpc_channel._pool[0]), "unary_unary" + ) as unary_unary: + channel_next = None + for i in range(pool_size): + channel_last = channel_next + channel_next = client.transport.grpc_channel._pool[i] + assert channel_last != channel_next + next_channel.return_value = channel_next + client.transport.ping_and_warm() + assert next_channel.call_count == i + 1 + unary_unary.assert_called_once() + unary_unary.reset_mock() + client.close() + + def test_channel_pool_replace(self): + import time + + sleep_module = asyncio if self.is_async else time + with mock.patch.object(sleep_module, "sleep"): + pool_size = 7 + client = self._make_client(project="project-id", pool_size=pool_size) + for replace_idx in range(pool_size): + start_pool = [ + channel for channel in client.transport._grpc_channel._pool + ] + grace_period = 9 + with mock.patch.object( + type(client.transport._grpc_channel._pool[-1]), "close" + ) as close: + new_channel = client.transport.create_channel() + client.transport.replace_channel( + replace_idx, grace=grace_period, new_channel=new_channel + ) + close.assert_called_once() + if self.is_async: + close.assert_called_once_with(grace=grace_period) + close.assert_awaited_once() + assert client.transport._grpc_channel._pool[replace_idx] == new_channel + for i in range(pool_size): + if i != replace_idx: + assert client.transport._grpc_channel._pool[i] == start_pool[i] + else: + assert client.transport._grpc_channel._pool[i] != start_pool[i] + client.close() + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test__start_background_channel_refresh_sync(self): + client = self._make_client(project="project-id", use_emulator=False) + with pytest.raises(RuntimeError): + client._start_background_channel_refresh() + + def test__start_background_channel_refresh_tasks_exist(self): + client = self._make_client(project="project-id", use_emulator=False) + assert len(client._channel_refresh_tasks) > 0 + with mock.patch.object(asyncio, "create_task") as create_task: + client._start_background_channel_refresh() + create_task.assert_not_called() + client.close() + + @pytest.mark.parametrize("pool_size", [1, 3, 7]) + def test__start_background_channel_refresh(self, pool_size): + import concurrent.futures + + with mock.patch.object( + self._get_target_class(), "_ping_and_warm_instances", AsyncMock() + ) as ping_and_warm: + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + client._start_background_channel_refresh() + assert len(client._channel_refresh_tasks) == pool_size + for task in client._channel_refresh_tasks: + if self.is_async: + assert isinstance(task, asyncio.Task) + else: + assert isinstance(task, concurrent.futures.Future) + asyncio.sleep(0.1) + assert ping_and_warm.call_count == pool_size + for channel in client.transport._grpc_channel._pool: + ping_and_warm.assert_any_call(channel) + client.close() + + @pytest.mark.skipif( + sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" + ) + def test__start_background_channel_refresh_tasks_names(self): + pool_size = 3 + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + for i in range(pool_size): + name = client._channel_refresh_tasks[i].get_name() + assert str(i) in name + assert "BigtableDataClientAsync channel refresh " in name + client.close() + + def test__ping_and_warm_instances(self): + """test ping and warm with mocked 
asyncio.gather""" + from google.cloud.bigtable.data._sync.cross_sync import CrossSync + + client_mock = mock.Mock() + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) + with mock.patch.object( + CrossSync._Sync_Impl, "gather_partials", AsyncMock() + ) as gather: + gather.side_effect = lambda partials, **kwargs: [None for _ in partials] + channel = mock.Mock() + client_mock._active_instances = [] + result = self._get_target_class()._ping_and_warm_instances( + client_mock, channel + ) + assert len(result) == 0 + assert gather.call_args.kwargs["return_exceptions"] is True + assert gather.call_args.kwargs["sync_executor"] == client_mock._executor + client_mock._active_instances = [ + (mock.Mock(), mock.Mock(), mock.Mock()) + ] * 4 + gather.reset_mock() + channel.reset_mock() + result = self._get_target_class()._ping_and_warm_instances( + client_mock, channel + ) + assert len(result) == 4 + gather.assert_called_once() + partial_list = gather.call_args.args[0] + assert len(partial_list) == 4 + if self.is_async: + gather.assert_awaited_once() + grpc_call_args = channel.unary_unary().call_args_list + for idx, (_, kwargs) in enumerate(grpc_call_args): + ( + expected_instance, + expected_table, + expected_app_profile, + ) = client_mock._active_instances[idx] + request = kwargs["request"] + assert request["name"] == expected_instance + assert request["app_profile_id"] == expected_app_profile + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert ( + metadata[0][1] + == f"name={expected_instance}&app_profile_id={expected_app_profile}" + ) + + def test__ping_and_warm_single_instance(self): + """should be able to call ping and warm with single instance""" + client_mock = mock.Mock() + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) + gather_tuple = ( + (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") + ) + with mock.patch.object(*gather_tuple, AsyncMock()) as gather: + gather.side_effect = lambda *args, **kwargs: [mock.Mock() for _ in args] + if self.is_async: + gather.side_effect = lambda *args, **kwargs: [None for _ in args] + else: + gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] + channel = mock.Mock() + client_mock._active_instances = [mock.Mock()] * 100 + test_key = ("test-instance", "test-table", "test-app-profile") + result = self._get_target_class()._ping_and_warm_instances( + client_mock, channel, test_key + ) + assert len(result) == 1 + grpc_call_args = channel.unary_unary().call_args_list + assert len(grpc_call_args) == 1 + kwargs = grpc_call_args[0][1] + request = kwargs["request"] + assert request["name"] == "test-instance" + assert request["app_profile_id"] == "test-app-profile" + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert ( + metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" + ) + + @pytest.mark.parametrize( + "refresh_interval, wait_time, expected_sleep", + [(0, 0, 0), (0, 1, 0), (10, 0, 10), (10, 5, 5), (10, 10, 0), (10, 15, 0)], + ) + def test__manage_channel_first_sleep( + self, refresh_interval, wait_time, expected_sleep + ): + import threading + import time + + with mock.patch.object(time, "monotonic") as monotonic: + monotonic.return_value = 0 + sleep_tuple = ( + (asyncio, "sleep") if self.is_async else 
(threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: + sleep.side_effect = asyncio.CancelledError + try: + client = self._make_client(project="project-id") + client._channel_init_time = -wait_time + client._manage_channel(0, refresh_interval, refresh_interval) + except asyncio.CancelledError: + pass + sleep.assert_called_once() + call_time = sleep.call_args[0][0] + assert ( + abs(call_time - expected_sleep) < 0.1 + ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" + client.close() + + def test__manage_channel_ping_and_warm(self): + """_manage channel should call ping and warm internally""" + import time + import threading + + client_mock = mock.Mock() + client_mock._is_closed.is_set.return_value = False + client_mock._channel_init_time = time.monotonic() + channel_list = [mock.Mock(), mock.Mock()] + client_mock.transport.channels = channel_list + new_channel = mock.Mock() + client_mock.transport.grpc_channel._create_channel.return_value = new_channel + sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + with mock.patch.object(*sleep_tuple): + client_mock.transport.replace_channel.side_effect = asyncio.CancelledError + ping_and_warm = client_mock._ping_and_warm_instances = AsyncMock() + try: + channel_idx = 1 + self._get_target_class()._manage_channel(client_mock, channel_idx, 10) + except asyncio.CancelledError: + pass + assert ping_and_warm.call_count == 2 + assert client_mock.transport.replace_channel.call_count == 1 + old_channel = channel_list[channel_idx] + assert old_channel != new_channel + called_with = [call[0][0] for call in ping_and_warm.call_args_list] + assert old_channel in called_with + assert new_channel in called_with + ping_and_warm.reset_mock() + try: + self._get_target_class()._manage_channel(client_mock, 0, 0, 0) + except asyncio.CancelledError: + pass + ping_and_warm.assert_called_once_with(new_channel) + + @pytest.mark.parametrize( + "refresh_interval, num_cycles, expected_sleep", + [(None, 1, 60 * 35), (10, 10, 100), (10, 1, 10)], + ) + def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sleep): + import time + import random + import threading + + channel_idx = 1 + with mock.patch.object(random, "uniform") as uniform: + uniform.side_effect = lambda min_, max_: min_ + with mock.patch.object(time, "time") as time_mock: + time_mock.return_value = 0 + sleep_tuple = ( + (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: + sleep.side_effect = [None for i in range(num_cycles - 1)] + [ + asyncio.CancelledError + ] + client = self._make_client(project="project-id") + with mock.patch.object(client.transport, "replace_channel"): + try: + if refresh_interval is not None: + client._manage_channel( + channel_idx, refresh_interval, refresh_interval + ) + else: + client._manage_channel(channel_idx) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + total_sleep = sum([call[0][0] for call in sleep.call_args_list]) + assert ( + abs(total_sleep - expected_sleep) < 0.1 + ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" + client.close() + + def test__manage_channel_random(self): + import random + import threading + + sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + with mock.patch.object(*sleep_tuple) as sleep: + with mock.patch.object(random, "uniform") as uniform: + 
uniform.return_value = 0 + try: + uniform.side_effect = asyncio.CancelledError + client = self._make_client(project="project-id", pool_size=1) + except asyncio.CancelledError: + uniform.side_effect = None + uniform.reset_mock() + sleep.reset_mock() + min_val = 200 + max_val = 205 + uniform.side_effect = lambda min_, max_: min_ + sleep.side_effect = [None, None, asyncio.CancelledError] + try: + with mock.patch.object(client.transport, "replace_channel"): + client._manage_channel(0, min_val, max_val) + except asyncio.CancelledError: + pass + assert uniform.call_count == 3 + uniform_args = [call[0] for call in uniform.call_args_list] + for found_min, found_max in uniform_args: + assert found_min == min_val + assert found_max == max_val + + @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) + def test__manage_channel_refresh(self, num_cycles): + import threading + + expected_grace = 9 + expected_refresh = 0.5 + channel_idx = 1 + grpc_lib = grpc.aio if self.is_async else grpc + new_channel = grpc_lib.insecure_channel("localhost:8080") + with mock.patch.object( + PooledBigtableGrpcAsyncIOTransport, "replace_channel" + ) as replace_channel: + sleep_tuple = ( + (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: + sleep.side_effect = [None for i in range(num_cycles)] + [ + asyncio.CancelledError + ] + with mock.patch.object( + grpc_helpers_async, "create_channel" + ) as create_channel: + create_channel.return_value = new_channel + with mock.patch.object( + self._get_target_class(), "_start_background_channel_refresh" + ): + client = self._make_client( + project="project-id", use_emulator=False + ) + create_channel.reset_mock() + try: + client._manage_channel( + channel_idx, + refresh_interval_min=expected_refresh, + refresh_interval_max=expected_refresh, + grace_period=expected_grace, + ) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + 1 + assert create_channel.call_count == num_cycles + assert replace_channel.call_count == num_cycles + for call in replace_channel.call_args_list: + args, kwargs = call + assert args[0] == channel_idx + assert kwargs["grace"] == expected_grace + assert kwargs["new_channel"] == new_channel + client.close() + + def test__register_instance(self): + """test instance registration""" + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: f"prefix/{b}" + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_tasks = [] + client_mock._start_background_channel_refresh.side_effect = ( + lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + ) + mock_channels = [mock.Mock() for i in range(5)] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = AsyncMock() + table_mock = mock.Mock() + self._get_target_class()._register_instance( + client_mock, "instance-1", table_mock + ) + assert client_mock._start_background_channel_refresh.call_count == 1 + expected_key = ( + "prefix/instance-1", + table_mock.table_name, + table_mock.app_profile_id, + ) + assert len(active_instances) == 1 + assert expected_key == tuple(list(active_instances)[0]) + assert len(instance_owners) == 1 + assert expected_key == tuple(list(instance_owners)[0]) + assert client_mock._channel_refresh_tasks + table_mock2 = mock.Mock() + self._get_target_class()._register_instance( + client_mock, 
"instance-2", table_mock2 + ) + assert client_mock._start_background_channel_refresh.call_count == 1 + assert client_mock._ping_and_warm_instances.call_count == len(mock_channels) + for channel in mock_channels: + assert channel in [ + call[0][0] + for call in client_mock._ping_and_warm_instances.call_args_list + ] + assert len(active_instances) == 2 + assert len(instance_owners) == 2 + expected_key2 = ( + "prefix/instance-2", + table_mock2.table_name, + table_mock2.app_profile_id, + ) + assert any( + [ + expected_key2 == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) + assert any( + [ + expected_key2 == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) + + @pytest.mark.parametrize( + "insert_instances,expected_active,expected_owner_keys", + [ + ([("i", "t", None)], [("i", "t", None)], [("i", "t", None)]), + ([("i", "t", "p")], [("i", "t", "p")], [("i", "t", "p")]), + ([("1", "t", "p"), ("1", "t", "p")], [("1", "t", "p")], [("1", "t", "p")]), + ( + [("1", "t", "p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + ), + ], + ) + def test__register_instance_state( + self, insert_instances, expected_active, expected_owner_keys + ): + """test that active_instances and instance_owners are updated as expected""" + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: b + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_tasks = [] + client_mock._start_background_channel_refresh.side_effect = ( + lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + ) + mock_channels = [mock.Mock() for i in range(5)] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = AsyncMock() + table_mock = mock.Mock() + for instance, table, profile in insert_instances: + table_mock.table_name = table + table_mock.app_profile_id = profile + self._get_target_class()._register_instance( + client_mock, instance, table_mock + ) + assert len(active_instances) == len(expected_active) + assert len(instance_owners) == len(expected_owner_keys) + for expected in expected_active: + assert any( + [ + expected == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) + for expected in expected_owner_keys: + assert any( + [ + expected == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) + + def test__remove_instance_registration(self): + client = self._make_client(project="project-id") + table = mock.Mock() + client._register_instance("instance-1", table) + client._register_instance("instance-2", table) + assert len(client._active_instances) == 2 + assert len(client._instance_owners.keys()) == 2 + instance_1_path = client._gapic_client.instance_path( + client.project, "instance-1" + ) + instance_1_key = (instance_1_path, table.table_name, table.app_profile_id) + instance_2_path = client._gapic_client.instance_path( + client.project, "instance-2" + ) + instance_2_key = (instance_2_path, table.table_name, table.app_profile_id) + assert len(client._instance_owners[instance_1_key]) == 1 + assert list(client._instance_owners[instance_1_key])[0] == id(table) + assert len(client._instance_owners[instance_2_key]) == 1 + assert list(client._instance_owners[instance_2_key])[0] == id(table) + success = client._remove_instance_registration("instance-1", table) + assert success 
+ assert len(client._active_instances) == 1 + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 1 + assert client._active_instances == {instance_2_key} + success = client._remove_instance_registration("fake-key", table) + assert not success + assert len(client._active_instances) == 1 + client.close() + + def test__multiple_table_registration(self): + """ + registering with multiple tables with the same key should + add multiple owners to instance_owners, but only keep one copy + of shared key in active_instances + """ + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + with self._make_client(project="project-id") as client: + with client.get_table("instance_1", "table_1") as table_1: + instance_1_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 1 + assert len(client._active_instances) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + with client.get_table("instance_1", "table_1") as table_2: + assert len(client._instance_owners[instance_1_key]) == 2 + assert len(client._active_instances) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + with client.get_table("instance_1", "table_3") as table_3: + instance_3_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_3_key = _WarmedInstanceKey( + instance_3_path, table_3.table_name, table_3.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 2 + assert len(client._instance_owners[instance_3_key]) == 1 + assert len(client._active_instances) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + assert id(table_3) in client._instance_owners[instance_3_key] + assert len(client._active_instances) == 1 + assert instance_1_key in client._active_instances + assert id(table_2) not in client._instance_owners[instance_1_key] + assert len(client._active_instances) == 0 + assert instance_1_key not in client._active_instances + assert len(client._instance_owners[instance_1_key]) == 0 + + def test__multiple_instance_registration(self): + """ + registering with multiple instance keys should update the key + in instance_owners and active_instances + """ + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + with self._make_client(project="project-id") as client: + with client.get_table("instance_1", "table_1") as table_1: + with client.get_table("instance_2", "table_2") as table_2: + instance_1_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id + ) + instance_2_path = client._gapic_client.instance_path( + client.project, "instance_2" + ) + instance_2_key = _WarmedInstanceKey( + instance_2_path, table_2.table_name, table_2.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 1 + assert len(client._instance_owners[instance_2_key]) == 1 + assert len(client._active_instances) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_2_key] + assert len(client._active_instances) == 1 + assert instance_1_key in 
client._active_instances + assert len(client._instance_owners[instance_2_key]) == 0 + assert len(client._instance_owners[instance_1_key]) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + assert len(client._active_instances) == 0 + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 0 + + def test_get_table(self): + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + client = self._make_client(project="project-id") + assert not client._active_instances + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + table = client.get_table( + expected_instance_id, expected_table_id, expected_app_profile_id + ) + asyncio.sleep(0) + assert isinstance(table, TestTableAsync._get_target_class()) + assert table.table_id == expected_table_id + assert ( + table.table_name + == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert table.instance_id == expected_instance_id + assert ( + table.instance_name + == f"projects/{client.project}/instances/{expected_instance_id}" + ) + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + client.close() + + def test_get_table_arg_passthrough(self): + """All arguments passed in get_table should be sent to constructor""" + with self._make_client(project="project-id") as client: + with mock.patch.object( + TestTableAsync._get_target_class(), "__init__" + ) as mock_constructor: + mock_constructor.return_value = None + assert not client._active_instances + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_args = (1, "test", {"test": 2}) + expected_kwargs = {"hello": "world", "test": 2} + client.get_table( + expected_instance_id, + expected_table_id, + expected_app_profile_id, + *expected_args, + **expected_kwargs, + ) + mock_constructor.assert_called_once_with( + client, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + *expected_args, + **expected_kwargs, + ) + + def test_get_table_context_manager(self): + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_project_id = "project-id" + with mock.patch.object( + TestTableAsync._get_target_class(), "close" + ) as close_mock: + with self._make_client(project=expected_project_id) as client: + with client.get_table( + expected_instance_id, expected_table_id, expected_app_profile_id + ) as table: + asyncio.sleep(0) + assert isinstance(table, TestTableAsync._get_target_class()) + assert table.table_id == expected_table_id + assert ( + table.table_name + == f"projects/{expected_project_id}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert table.instance_id == expected_instance_id + assert ( + table.instance_name + == f"projects/{expected_project_id}/instances/{expected_instance_id}" + ) + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in 
client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + assert close_mock.call_count == 1 + + def test_multiple_pool_sizes(self): + pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] + for pool_size in pool_sizes: + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client._channel_refresh_tasks) == pool_size + client_duplicate = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client_duplicate._channel_refresh_tasks) == pool_size + assert str(pool_size) in str(client.transport) + client.close() + client_duplicate.close() + + def test_close(self): + pool_size = 7 + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client._channel_refresh_tasks) == pool_size + tasks_list = list(client._channel_refresh_tasks) + for task in client._channel_refresh_tasks: + assert not task.done() + with mock.patch.object( + PooledBigtableGrpcAsyncIOTransport, "close", AsyncMock() + ) as close_mock: + client.close() + close_mock.assert_called_once() + close_mock.assert_awaited() + for task in tasks_list: + assert task.done() + assert client._channel_refresh_tasks == [] + + def test_close_with_timeout(self): + from google.cloud.bigtable.data._sync.cross_sync import CrossSync + + pool_size = 7 + expected_timeout = 19 + client = self._make_client(project="project-id", pool_size=pool_size) + tasks = list(client._channel_refresh_tasks) + with mock.patch.object( + CrossSync._Sync_Impl, "wait", AsyncMock() + ) as wait_for_mock: + client.close(timeout=expected_timeout) + wait_for_mock.assert_called_once() + wait_for_mock.assert_awaited() + assert wait_for_mock.call_args[1]["timeout"] == expected_timeout + client._channel_refresh_tasks = tasks + client.close() + + def test_context_manager(self): + close_mock = AsyncMock() + true_close = None + with self._make_client(project="project-id") as client: + true_close = client.close() + client.close = close_mock + for task in client._channel_refresh_tasks: + assert not task.done() + assert client.project == "project-id" + assert client._active_instances == set() + close_mock.assert_not_called() + close_mock.assert_called_once() + close_mock.assert_awaited() + true_close + + def test_client_ctor_sync(self): + with pytest.warns(RuntimeWarning) as warnings: + client = self._make_client(project="project-id", use_emulator=False) + expected_warning = [w for w in warnings if "client.py" in w.filename] + assert len(expected_warning) == 1 + assert ( + "BigtableDataClientAsync should be started in an asyncio event loop." 
+ in str(expected_warning[0].message) + ) + assert client.project == "project-id" + assert client._channel_refresh_tasks == [] + + +class TestBulkMutateRows: + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + + def _mock_response(self, response_list): + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc import status_pb2 + + statuses = [] + for response in response_list: + if isinstance(response, core_exceptions.GoogleAPICallError): + statuses.append( + status_pb2.Status( + message=str(response), code=response.grpc_status_code.value[0] + ) + ) + else: + statuses.append(status_pb2.Status(code=0)) + entries = [ + MutateRowsResponse.Entry(index=i, status=statuses[i]) + for i in range(len(response_list)) + ] + + def generator(): + yield MutateRowsResponse(entries=entries) + + return generator() + + @pytest.mark.parametrize( + "mutation_arg", + [ + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=1234567890 + ) + ], + [mutations.DeleteRangeFromColumn("family", b"qualifier")], + [mutations.DeleteAllFromFamily("family")], + [mutations.DeleteAllFromRow()], + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromRow(), + ], + ], + ) + def test_bulk_mutate_rows(self, mutation_arg): + """Test mutations with no errors""" + expected_attempt_timeout = 19 + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.return_value = self._mock_response([None]) + bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg) + table.bulk_mutate_rows( + [bulk_mutation], attempt_timeout=expected_attempt_timeout + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"] == [bulk_mutation._to_pb()] + assert kwargs["timeout"] == expected_attempt_timeout + assert kwargs["retry"] is None + + def test_bulk_mutate_rows_multiple_entries(self): + """Test mutations with no errors""" + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.return_value = self._mock_response([None, None]) + mutation_list = [mutations.DeleteAllFromRow()] + entry_1 = mutations.RowMutationEntry(b"row_key_1", mutation_list) + entry_2 = mutations.RowMutationEntry(b"row_key_2", mutation_list) + table.bulk_mutate_rows([entry_1, entry_2]) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"][0] == entry_1._to_pb() + assert kwargs["entries"][1] == entry_2._to_pb() + + @pytest.mark.parametrize( + "exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_rows_idempotent_mutation_error_retryable(self, exception): + """Individual idempotent mutations should be retried if they fail with a retryable error""" + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with 
self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], exception) + assert isinstance( + cause.exceptions[-1], core_exceptions.DeadlineExceeded + ) + + @pytest.mark.parametrize( + "exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + core_exceptions.Aborted, + ], + ) + def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable(self, exception): + """Individual idempotent mutations should not be retried if they fail with a non-retryable error""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, exception) + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_idempotent_retryable_request_errors(self, retryable_exception): + """Individual idempotent mutations should be retried if the request fails with a retryable error""" + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert 
isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_rows_non_idempotent_retryable_errors( + self, retryable_exception + ): + """Non-Idempotent mutations should never be retried""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [retryable_exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", -1 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is False + table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, retryable_exception) + + @pytest.mark.parametrize( + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + ], + ) + def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception): + """If the request fails with a non-retryable error, mutations should not be retried""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, non_retryable_exception) + + def test_bulk_mutate_error_index(self): + """Test partial failure, partial success. 
Errors should be associated with the correct index""" + from google.api_core.exceptions import ( + DeadlineExceeded, + ServiceUnavailable, + FailedPrecondition, + ) + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = [ + self._mock_response([None, ServiceUnavailable("mock"), None]), + self._mock_response([DeadlineExceeded("mock")]), + self._mock_response([FailedPrecondition("final")]), + ] + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entries = [ + mutations.RowMutationEntry( + f"row_key_{i}".encode(), [mutation] + ) + for i in range(3) + ] + assert mutation.is_idempotent() is True + table.bulk_mutate_rows(entries, operation_timeout=1000) + assert len(e.value.exceptions) == 1 + failed = e.value.exceptions[0] + assert isinstance(failed, FailedMutationEntryError) + assert failed.index == 1 + assert failed.entry == entries[1] + cause = failed.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert len(cause.exceptions) == 3 + assert isinstance(cause.exceptions[0], ServiceUnavailable) + assert isinstance(cause.exceptions[1], DeadlineExceeded) + assert isinstance(cause.exceptions[2], FailedPrecondition) + + def test_bulk_mutate_error_recovery(self): + """If an error occurs, then resolves, no exception should be raised""" + from google.api_core.exceptions import DeadlineExceeded + + with self._make_client(project="project") as client: + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.side_effect = [ + self._mock_response([DeadlineExceeded("mock")]), + self._mock_response([None]), + ] + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entries = [ + mutations.RowMutationEntry(f"row_key_{i}".encode(), [mutation]) + for i in range(3) + ] + table.bulk_mutate_rows(entries, operation_timeout=1000) + + +class TestCheckAndMutateRow: + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + + @pytest.mark.parametrize("gapic_result", [True, False]) + def test_check_and_mutate(self, gapic_result): + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + app_profile = "app_profile_id" + with self._make_client() as client: + with client.get_table( + "instance", "table", app_profile_id=app_profile + ) as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=gapic_result + ) + row_key = b"row_key" + predicate = None + true_mutations = [mock.Mock()] + false_mutations = [mock.Mock(), mock.Mock()] + operation_timeout = 0.2 + found = table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutations, + false_case_mutations=false_mutations, + operation_timeout=operation_timeout, + ) + assert found == gapic_result + kwargs = mock_gapic.call_args[1] + assert kwargs["table_name"] == table.table_name + assert kwargs["row_key"] == row_key + assert kwargs["predicate_filter"] == predicate + assert kwargs["true_mutations"] == [ + m._to_pb() for m in true_mutations + ] + assert 
kwargs["false_mutations"] == [ + m._to_pb() for m in false_mutations + ] + assert kwargs["app_profile_id"] == app_profile + assert kwargs["timeout"] == operation_timeout + assert kwargs["retry"] is None + + def test_check_and_mutate_bad_timeout(self): + """Should raise error if operation_timeout < 0""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=[mock.Mock()], + false_case_mutations=[], + operation_timeout=-1, + ) + assert str(e.value) == "operation_timeout must be greater than 0" + + def test_check_and_mutate_single_mutations(self): + """if single mutations are passed, they should be internally wrapped in a list""" + from google.cloud.bigtable.data.mutations import SetCell + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + true_mutation = SetCell("family", b"qualifier", b"value") + false_mutation = SetCell("family", b"qualifier", b"value") + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["true_mutations"] == [true_mutation._to_pb()] + assert kwargs["false_mutations"] == [false_mutation._to_pb()] + + def test_check_and_mutate_predicate_object(self): + """predicate filter should be passed to gapic request""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + mock_predicate = mock.Mock() + predicate_pb = {"predicate": "dict"} + mock_predicate._to_pb.return_value = predicate_pb + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + table.check_and_mutate_row( + b"row_key", mock_predicate, false_case_mutations=[mock.Mock()] + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["predicate_filter"] == predicate_pb + assert mock_predicate._to_pb.call_count == 1 + assert kwargs["retry"] is None + + def test_check_and_mutate_mutations_parsing(self): + """mutations objects should be converted to protos""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + from google.cloud.bigtable.data.mutations import DeleteAllFromRow + + mutations = [mock.Mock() for _ in range(5)] + for idx, mutation in enumerate(mutations): + mutation._to_pb.return_value = f"fake {idx}" + mutations.append(DeleteAllFromRow()) + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=mutations[0:2], + false_case_mutations=mutations[2:], + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["true_mutations"] == ["fake 0", "fake 1"] + assert kwargs["false_mutations"] == [ + "fake 2", + "fake 3", + "fake 4", + DeleteAllFromRow()._to_pb(), + ] + assert all( + (mutation._to_pb.call_count == 1 for mutation in 
mutations[:5]) + ) + + +class TestMutateRow: + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + + @pytest.mark.parametrize( + "mutation_arg", + [ + mutations.SetCell("family", b"qualifier", b"value"), + mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=1234567890 + ), + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromFamily("family"), + mutations.DeleteAllFromRow(), + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromRow(), + ], + ], + ) + def test_mutate_row(self, mutation_arg): + """Test mutations with no errors""" + expected_attempt_timeout = 19 + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.return_value = None + table.mutate_row( + "row_key", + mutation_arg, + attempt_timeout=expected_attempt_timeout, + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0].kwargs + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["row_key"] == b"row_key" + formatted_mutations = ( + [mutation._to_pb() for mutation in mutation_arg] + if isinstance(mutation_arg, list) + else [mutation_arg._to_pb()] + ) + assert kwargs["mutations"] == formatted_mutations + assert kwargs["timeout"] == expected_attempt_timeout + assert kwargs["retry"] is None + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_mutate_row_retryable_errors(self, retryable_exception): + from google.api_core.exceptions import DeadlineExceeded + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + mutation = mutations.DeleteAllFromRow() + assert mutation.is_idempotent() is True + table.mutate_row("row_key", mutation, operation_timeout=0.01) + cause = e.value.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_mutate_row_non_idempotent_retryable_errors(self, retryable_exception): + """Non-idempotent mutations should not be retried""" + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(retryable_exception): + mutation = mutations.SetCell( + "family", b"qualifier", b"value", -1 + ) + assert mutation.is_idempotent() is False + table.mutate_row("row_key", mutation, operation_timeout=0.2) + + @pytest.mark.parametrize( + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + core_exceptions.Aborted, + ], + ) + def test_mutate_row_non_retryable_errors(self, non_retryable_exception): + with 
self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + mutation = mutations.SetCell( + "family", + b"qualifier", + b"value", + timestamp_micros=1234567890, + ) + assert mutation.is_idempotent() is True + table.mutate_row("row_key", mutation, operation_timeout=0.2) + + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_mutate_row_metadata(self, include_app_profile): + """request should attach metadata headers""" + profile = "profile" if include_app_profile else None + with self._make_client() as client: + with client.get_table("i", "t", app_profile_id=profile) as table: + with mock.patch.object( + client._gapic_client, "mutate_row", AsyncMock() + ) as read_rows: + table.mutate_row("rk", mock.Mock()) + kwargs = read_rows.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == "x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata + + @pytest.mark.parametrize("mutations", [[], None]) + def test_mutate_row_no_mutations(self, mutations): + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.mutate_row("key", mutations=mutations) + assert e.value.args[0] == "No mutations provided" + + +class TestReadModifyWriteRow: + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + + @pytest.mark.parametrize( + "call_rules,expected_rules", + [ + ( + AppendValueRule("f", "c", b"1"), + [AppendValueRule("f", "c", b"1")._to_pb()], + ), + ( + [AppendValueRule("f", "c", b"1")], + [AppendValueRule("f", "c", b"1")._to_pb()], + ), + (IncrementRule("f", "c", 1), [IncrementRule("f", "c", 1)._to_pb()]), + ( + [AppendValueRule("f", "c", b"1"), IncrementRule("f", "c", 1)], + [ + AppendValueRule("f", "c", b"1")._to_pb(), + IncrementRule("f", "c", 1)._to_pb(), + ], + ), + ], + ) + def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): + """Test that the gapic call is called with given rules""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row("key", call_rules) + assert mock_gapic.call_count == 1 + found_kwargs = mock_gapic.call_args_list[0][1] + assert found_kwargs["rules"] == expected_rules + assert found_kwargs["retry"] is None + + @pytest.mark.parametrize("rules", [[], None]) + def test_read_modify_write_no_rules(self, rules): + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.read_modify_write_row("key", rules=rules) + assert e.value.args[0] == "rules must contain at least one item" + + def test_read_modify_write_call_defaults(self): + instance = "instance1" + table_id = "table1" + project = "project1" + row_key = "row_key1" + with self._make_client(project=project) as client: + with client.get_table(instance, table_id) as table: + with 
mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert ( + kwargs["table_name"] + == f"projects/{project}/instances/{instance}/tables/{table_id}" + ) + assert kwargs["app_profile_id"] is None + assert kwargs["row_key"] == row_key.encode() + assert kwargs["timeout"] > 1 + + def test_read_modify_write_call_overrides(self): + row_key = b"row_key1" + expected_timeout = 12345 + profile_id = "profile1" + with self._make_client() as client: + with client.get_table( + "instance", "table_id", app_profile_id=profile_id + ) as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row( + row_key, mock.Mock(), operation_timeout=expected_timeout + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["app_profile_id"] is profile_id + assert kwargs["row_key"] == row_key + assert kwargs["timeout"] == expected_timeout + + def test_read_modify_write_string_key(self): + row_key = "string_row_key1" + with self._make_client() as client: + with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["row_key"] == row_key.encode() + + def test_read_modify_write_row_building(self): + """results from gapic call should be used to construct row""" + from google.cloud.bigtable.data.row import Row + from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + from google.cloud.bigtable_v2.types import Row as RowPB + + mock_response = ReadModifyWriteRowResponse(row=RowPB()) + with self._make_client() as client: + with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + with mock.patch.object(Row, "_from_pb") as constructor_mock: + mock_gapic.return_value = mock_response + table.read_modify_write_row("key", mock.Mock()) + assert constructor_mock.call_count == 1 + constructor_mock.assert_called_once_with(mock_response.row) + + +class TestReadRows: + """ + Tests for table.read_rows and related methods. 
+ """ + + @staticmethod + def _get_operation_class(): + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + + return _ReadRowsOperationAsync + + def _make_client(self, *args, **kwargs): + return TestBigtableDataClientAsync._make_client(*args, **kwargs) + + def _make_table(self, *args, **kwargs): + from google.cloud.bigtable.data._async.client import TableAsync + + client_mock = mock.Mock() + client_mock._register_instance.side_effect = ( + lambda *args, **kwargs: asyncio.sleep(0) + ) + client_mock._remove_instance_registration.side_effect = ( + lambda *args, **kwargs: asyncio.sleep(0) + ) + kwargs["instance_id"] = kwargs.get( + "instance_id", args[0] if args else "instance" + ) + kwargs["table_id"] = kwargs.get( + "table_id", args[1] if len(args) > 1 else "table" + ) + client_mock._gapic_client.table_path.return_value = kwargs["table_id"] + client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] + return TableAsync(client_mock, *args, **kwargs) + + def _make_stats(self): + from google.cloud.bigtable_v2.types import RequestStats + from google.cloud.bigtable_v2.types import FullReadStatsView + from google.cloud.bigtable_v2.types import ReadIterationStats + + return RequestStats( + full_read_stats_view=FullReadStatsView( + read_iteration_stats=ReadIterationStats( + rows_seen_count=1, + rows_returned_count=2, + cells_seen_count=3, + cells_returned_count=4, + ) + ) + ) + + @staticmethod + def _make_chunk(*args, **kwargs): + from google.cloud.bigtable_v2 import ReadRowsResponse + + kwargs["row_key"] = kwargs.get("row_key", b"row_key") + kwargs["family_name"] = kwargs.get("family_name", "family_name") + kwargs["qualifier"] = kwargs.get("qualifier", b"qualifier") + kwargs["value"] = kwargs.get("value", b"value") + kwargs["commit_row"] = kwargs.get("commit_row", True) + return ReadRowsResponse.CellChunk(*args, **kwargs) + + @staticmethod + def _make_gapic_stream( + chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0 + ): + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list, sleep_time): + self.chunk_list = chunk_list + self.idx = -1 + self.sleep_time = sleep_time + + def __aiter__(self): + return self + + def __anext__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + if sleep_time: + asyncio.sleep(self.sleep_time) + chunk = self.chunk_list[self.idx] + if isinstance(chunk, Exception): + raise chunk + else: + return ReadRowsResponse(chunks=[chunk]) + raise StopAsyncIteration + + def cancel(self): + pass + + return mock_stream(chunk_list, sleep_time) + + def execute_fn(self, table, *args, **kwargs): + return table.read_rows(*args, **kwargs) + + def test_read_rows(self): + query = ReadRowsQuery() + chunks = [ + self._make_chunk(row_key=b"test_1"), + self._make_chunk(row_key=b"test_2"), + ] + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + results = self.execute_fn(table, query, operation_timeout=3) + assert len(results) == 2 + assert results[0].row_key == b"test_1" + assert results[1].row_key == b"test_2" + + def test_read_rows_stream(self): + query = ReadRowsQuery() + chunks = [ + self._make_chunk(row_key=b"test_1"), + self._make_chunk(row_key=b"test_2"), + ] + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + gen 
= table.read_rows_stream(query, operation_timeout=3) + results = [row for row in gen] + assert len(results) == 2 + assert results[0].row_key == b"test_1" + assert results[1].row_key == b"test_2" + + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_read_rows_query_matches_request(self, include_app_profile): + from google.cloud.bigtable.data import RowRange + from google.cloud.bigtable.data.row_filters import PassAllFilter + + app_profile_id = "app_profile_id" if include_app_profile else None + with self._make_table(app_profile_id=app_profile_id) as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream([]) + row_keys = [b"test_1", "test_2"] + row_ranges = RowRange("1start", "2end") + filter_ = PassAllFilter(True) + limit = 99 + query = ReadRowsQuery( + row_keys=row_keys, + row_ranges=row_ranges, + row_filter=filter_, + limit=limit, + ) + results = table.read_rows(query, operation_timeout=3) + assert len(results) == 0 + call_request = read_rows.call_args_list[0][0][0] + query_pb = query._to_pb(table) + assert call_request == query_pb + + @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) + def test_read_rows_timeout(self, operation_timeout): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + query = ReadRowsQuery() + chunks = [self._make_chunk(row_key=b"test_1")] + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=0.15 + ) + try: + table.read_rows(query, operation_timeout=operation_timeout) + except core_exceptions.DeadlineExceeded as e: + assert ( + e.message + == f"operation_timeout of {operation_timeout:0.1f}s exceeded" + ) + + @pytest.mark.parametrize( + "per_request_t, operation_t, expected_num", + [(0.05, 0.08, 2), (0.05, 0.14, 3), (0.05, 0.24, 5)], + ) + def test_read_rows_attempt_timeout(self, per_request_t, operation_t, expected_num): + """ + Ensures that the attempt_timeout is respected and that the number of + requests is as expected. + + operation_timeout does not cancel the request, so we expect the number of + requests to be the ceiling of operation_timeout / attempt_timeout. 
+ """ + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + expected_last_timeout = operation_t - (expected_num - 1) * per_request_t + with mock.patch("random.uniform", side_effect=lambda a, b: 0): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=per_request_t + ) + query = ReadRowsQuery() + chunks = [core_exceptions.DeadlineExceeded("mock deadline")] + try: + table.read_rows( + query, + operation_timeout=operation_t, + attempt_timeout=per_request_t, + ) + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + if expected_num == 0: + assert retry_exc is None + else: + assert type(retry_exc) is RetryExceptionGroup + assert f"{expected_num} failed attempts" in str(retry_exc) + assert len(retry_exc.exceptions) == expected_num + for sub_exc in retry_exc.exceptions: + assert sub_exc.message == "mock deadline" + assert read_rows.call_count == expected_num + for _, call_kwargs in read_rows.call_args_list[:-1]: + assert call_kwargs["timeout"] == per_request_t + assert call_kwargs["retry"] is None + assert ( + abs( + read_rows.call_args_list[-1][1]["timeout"] + - expected_last_timeout + ) + < 0.05 + ) + + @pytest.mark.parametrize( + "exc_type", + [ + core_exceptions.Aborted, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ], + ) + def test_read_rows_retryable_error(self, exc_type): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + try: + table.read_rows(query, operation_timeout=0.1) + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + root_cause = retry_exc.exceptions[0] + assert type(root_cause) is exc_type + assert root_cause == expected_error + + @pytest.mark.parametrize( + "exc_type", + [ + core_exceptions.Cancelled, + core_exceptions.PreconditionFailed, + core_exceptions.NotFound, + core_exceptions.PermissionDenied, + core_exceptions.Conflict, + core_exceptions.InternalServerError, + core_exceptions.TooManyRequests, + core_exceptions.ResourceExhausted, + InvalidChunk, + ], + ) + def test_read_rows_non_retryable_error(self, exc_type): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + try: + table.read_rows(query, operation_timeout=0.1) + except exc_type as e: + assert e == expected_error + + def test_read_rows_revise_request(self): + """Ensure that _revise_request is called between retries""" + from google.cloud.bigtable.data.exceptions import InvalidChunk + from google.cloud.bigtable_v2.types import RowSet + + return_val = RowSet() + with mock.patch.object( + self._get_operation_class(), "_revise_request_rowset" + ) as revise_rowset: + revise_rowset.return_value = return_val + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + row_keys = [b"test_1", b"test_2", b"test_3"] + query = ReadRowsQuery(row_keys=row_keys) + chunks = [ + self._make_chunk(row_key=b"test_1"), + core_exceptions.Aborted("mock retryable error"), + ] + try: + 
table.read_rows(query) + except InvalidChunk: + revise_rowset.assert_called() + first_call_kwargs = revise_rowset.call_args_list[0].kwargs + assert first_call_kwargs["row_set"] == query._to_pb(table).rows + assert first_call_kwargs["last_seen_row_key"] == b"test_1" + revised_call = read_rows.call_args_list[1].args[0] + assert revised_call.rows == return_val + + def test_read_rows_default_timeouts(self): + """Ensure that the default timeouts are set on the read rows operation when not overridden""" + operation_timeout = 8 + attempt_timeout = 4 + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: + mock_op.side_effect = RuntimeError("mock error") + with self._make_table( + default_read_rows_operation_timeout=operation_timeout, + default_read_rows_attempt_timeout=attempt_timeout, + ) as table: + try: + table.read_rows(ReadRowsQuery()) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["attempt_timeout"] == attempt_timeout + + def test_read_rows_default_timeout_override(self): + """When timeouts are passed, they overwrite default values""" + operation_timeout = 8 + attempt_timeout = 4 + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: + mock_op.side_effect = RuntimeError("mock error") + with self._make_table( + default_operation_timeout=99, default_attempt_timeout=97 + ) as table: + try: + table.read_rows( + ReadRowsQuery(), + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + ) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["attempt_timeout"] == attempt_timeout + + def test_read_row(self): + """Test reading a single row""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + expected_result = object() + read_rows.side_effect = lambda *args, **kwargs: [expected_result] + expected_op_timeout = 8 + expected_req_timeout = 4 + row = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert row == expected_result + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert len(args) == 1 + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + + def test_read_row_w_filter(self): + """Test reading a single row with an added filter""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + expected_result = object() + read_rows.side_effect = lambda *args, **kwargs: [expected_result] + expected_op_timeout = 8 + expected_req_timeout = 4 + mock_filter = mock.Mock() + expected_filter = {"filter": "mock filter"} + mock_filter._to_dict.return_value = expected_filter + row = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + row_filter=expected_filter, + ) + assert row == expected_result + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert 
kwargs["attempt_timeout"] == expected_req_timeout + assert len(args) == 1 + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + assert query.filter == expected_filter + + def test_read_row_no_response(self): + """should return None if row does not exist""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = lambda *args, **kwargs: [] + expected_op_timeout = 8 + expected_req_timeout = 4 + result = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert result is None + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + + @pytest.mark.parametrize( + "return_value,expected_result", + [([], False), ([object()], True), ([object(), object()], True)], + ) + def test_row_exists(self, return_value, expected_result): + """Test checking for row existence""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = lambda *args, **kwargs: return_value + expected_op_timeout = 1 + expected_req_timeout = 2 + result = table.row_exists( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert expected_result == result + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert isinstance(args[0], ReadRowsQuery) + expected_filter = { + "chain": { + "filters": [ + {"cells_per_row_limit_filter": 1}, + {"strip_value_transformer": True}, + ] + } + } + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + assert query.filter._to_dict() == expected_filter + + +class TestReadRowsSharded: + def _make_client(self, *args, **kwargs): + return TestBigtableDataClientAsync._make_client(*args, **kwargs) + + def test_read_rows_sharded_empty_query(self): + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as exc: + table.read_rows_sharded([]) + assert "empty sharded_query" in str(exc.value) + + def test_read_rows_sharded_multiple_queries(self): + """Test with multiple queries. 
Should return results from both""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "read_rows" + ) as read_rows: + read_rows.side_effect = ( + lambda *args, **kwargs: TestReadRowsAsync._make_gapic_stream( + [ + TestReadRowsAsync._make_chunk(row_key=k) + for k in args[0].rows.row_keys + ] + ) + ) + query_1 = ReadRowsQuery(b"test_1") + query_2 = ReadRowsQuery(b"test_2") + result = table.read_rows_sharded([query_1, query_2]) + assert len(result) == 2 + assert result[0].row_key == b"test_1" + assert result[1].row_key == b"test_2" + + @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24]) + def test_read_rows_sharded_multiple_queries_calls(self, n_queries): + """Each query should trigger a separate read_rows call""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + query_list = [ReadRowsQuery() for _ in range(n_queries)] + table.read_rows_sharded(query_list) + assert read_rows.call_count == n_queries + + def test_read_rows_sharded_errors(self): + """Errors should be exposed as ShardedReadRowsExceptionGroups""" + from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedQueryShardError + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = RuntimeError("mock error") + query_1 = ReadRowsQuery(b"test_1") + query_2 = ReadRowsQuery(b"test_2") + with pytest.raises(ShardedReadRowsExceptionGroup) as exc: + table.read_rows_sharded([query_1, query_2]) + exc_group = exc.value + assert isinstance(exc_group, ShardedReadRowsExceptionGroup) + assert len(exc.value.exceptions) == 2 + assert isinstance(exc.value.exceptions[0], FailedQueryShardError) + assert isinstance(exc.value.exceptions[0].__cause__, RuntimeError) + assert exc.value.exceptions[0].index == 0 + assert exc.value.exceptions[0].query == query_1 + assert isinstance(exc.value.exceptions[1], FailedQueryShardError) + assert isinstance(exc.value.exceptions[1].__cause__, RuntimeError) + assert exc.value.exceptions[1].index == 1 + assert exc.value.exceptions[1].query == query_2 + + def test_read_rows_sharded_concurrent(self): + """Ensure sharded requests are concurrent""" + import time + + def mock_call(*args, **kwargs): + asyncio.sleep(0.1) + return [mock.Mock()] + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = mock_call + queries = [ReadRowsQuery() for _ in range(10)] + start_time = time.monotonic() + result = table.read_rows_sharded(queries) + call_time = time.monotonic() - start_time + assert read_rows.call_count == 10 + assert len(result) == 10 + assert call_time < 0.2 + + def test_read_rows_sharded_concurrency_limit(self): + """ + Only 10 queries should be processed concurrently. 
Others should be queued
+
+        Should start a new query as soon as the previous one finishes
+        """
+        from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT
+
+        assert _CONCURRENCY_LIMIT == 10
+        num_queries = 15
+        increment_time = 0.05
+        max_time = increment_time * (_CONCURRENCY_LIMIT - 1)
+        rpc_times = [min(i * increment_time, max_time) for i in range(num_queries)]
+
+        def mock_call(*args, **kwargs):
+            next_sleep = rpc_times.pop(0)
+            asyncio.sleep(next_sleep)
+            return [mock.Mock()]
+
+        starting_timeout = 10
+        with self._make_client() as client:
+            with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(num_queries)]
+                    table.read_rows_sharded(queries, operation_timeout=starting_timeout)
+                    assert read_rows.call_count == num_queries
+                    rpc_start_list = [
+                        starting_timeout - kwargs["operation_timeout"]
+                        for _, kwargs in read_rows.call_args_list
+                    ]
+                    eps = 0.01
+                    assert all(
+                        (rpc_start_list[i] < eps for i in range(_CONCURRENCY_LIMIT))
+                    )
+                    for i in range(num_queries - _CONCURRENCY_LIMIT):
+                        idx = i + _CONCURRENCY_LIMIT
+                        assert rpc_start_list[idx] - i * increment_time < eps
+
+    def test_read_rows_sharded_expiry(self):
+        """
+        If the operation times out before all shards complete, should raise
+        a ShardedReadRowsExceptionGroup
+        """
+        from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT
+        from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+        from google.api_core.exceptions import DeadlineExceeded
+
+        operation_timeout = 0.1
+        num_queries = 15
+        sleeps = [0] * _CONCURRENCY_LIMIT + [DeadlineExceeded("times up")] * (
+            num_queries - _CONCURRENCY_LIMIT
+        )
+
+        def mock_call(*args, **kwargs):
+            next_item = sleeps.pop(0)
+            if isinstance(next_item, Exception):
+                raise next_item
+            else:
+                asyncio.sleep(next_item)
+            return [mock.Mock()]
+
+        with self._make_client() as client:
+            with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(num_queries)]
+                    with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
+                        table.read_rows_sharded(
+                            queries, operation_timeout=operation_timeout
+                        )
+                    assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
+                    assert len(exc.value.exceptions) == num_queries - _CONCURRENCY_LIMIT
+                    assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT
+
+    def test_read_rows_sharded_negative_batch_timeout(self):
+        """
+        Try to run with a batch that starts after the operation timeout.
+
+        Those queries should raise DeadlineExceeded errors
+        """
+        from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+        from google.api_core.exceptions import DeadlineExceeded
+
+        def mock_call(*args, **kwargs):
+            asyncio.sleep(0.05)
+            return [mock.Mock()]
+
+        with self._make_client() as client:
+            with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(15)]
+                    with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
+                        table.read_rows_sharded(queries, operation_timeout=0.01)
+                    assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
+                    assert len(exc.value.exceptions) == 5
+                    assert all(
+                        (
+                            isinstance(e.__cause__, DeadlineExceeded)
+                            for e in exc.value.exceptions
+                        )
+                    )
+
+
+class TestSampleRowKeys:
+    def _make_client(self, *args, **kwargs):
+
return TestBigtableDataClientAsync._make_client(*args, **kwargs) + + def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): + from google.cloud.bigtable_v2.types import SampleRowKeysResponse + + for value in sample_list: + yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) + + def test_sample_row_keys(self): + """Test that method returns the expected key samples""" + samples = [(b"test_1", 0), (b"test_2", 100), (b"test_3", 200)] + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", AsyncMock() + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream(samples) + result = table.sample_row_keys() + assert len(result) == 3 + assert all((isinstance(r, tuple) for r in result)) + assert all((isinstance(r[0], bytes) for r in result)) + assert all((isinstance(r[1], int) for r in result)) + assert result[0] == samples[0] + assert result[1] == samples[1] + assert result[2] == samples[2] + + def test_sample_row_keys_bad_timeout(self): + """should raise error if timeout is negative""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.sample_row_keys(operation_timeout=-1) + assert "operation_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + table.sample_row_keys(attempt_timeout=-1) + assert "attempt_timeout must be greater than 0" in str(e.value) + + def test_sample_row_keys_default_timeout(self): + """Should fallback to using table default operation_timeout""" + expected_timeout = 99 + with self._make_client() as client: + with client.get_table( + "i", + "t", + default_operation_timeout=expected_timeout, + default_attempt_timeout=expected_timeout, + ) as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", AsyncMock() + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream([]) + result = table.sample_row_keys() + _, kwargs = sample_row_keys.call_args + assert abs(kwargs["timeout"] - expected_timeout) < 0.1 + assert result == [] + assert kwargs["retry"] is None + + def test_sample_row_keys_gapic_params(self): + """make sure arguments are propagated to gapic call as expected""" + expected_timeout = 10 + expected_profile = "test1" + instance = "instance_name" + table_id = "my_table" + with self._make_client() as client: + with client.get_table( + instance, table_id, app_profile_id=expected_profile + ) as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", AsyncMock() + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream([]) + table.sample_row_keys(attempt_timeout=expected_timeout) + args, kwargs = sample_row_keys.call_args + assert len(args) == 0 + assert len(kwargs) == 5 + assert kwargs["timeout"] == expected_timeout + assert kwargs["app_profile_id"] == expected_profile + assert kwargs["table_name"] == table.table_name + assert kwargs["metadata"] is not None + assert kwargs["retry"] is None + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_sample_row_keys_retryable_errors(self, retryable_exception): + """retryable errors should be retried until timeout""" + from google.api_core.exceptions import DeadlineExceeded + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + with self._make_client() 
as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", AsyncMock() + ) as sample_row_keys: + sample_row_keys.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + table.sample_row_keys(operation_timeout=0.05) + cause = e.value.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert len(cause.exceptions) > 0 + assert isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.parametrize( + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + core_exceptions.Aborted, + ], + ) + def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): + """non-retryable errors should cause a raise""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", AsyncMock() + ) as sample_row_keys: + sample_row_keys.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + table.sample_row_keys() + + +class TestTable: + def _make_client(self, *args, **kwargs): + return TestBigtableDataClientAsync._make_client(*args, **kwargs) + + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.data._async.client import TableAsync + + return TableAsync + + @property + def is_async(self): + return True + + def test_table_ctor(self): + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_operation_timeout = 123 + expected_attempt_timeout = 12 + expected_read_rows_operation_timeout = 1.5 + expected_read_rows_attempt_timeout = 0.5 + expected_mutate_rows_operation_timeout = 2.5 + expected_mutate_rows_attempt_timeout = 0.75 + client = self._make_client() + assert not client._active_instances + table = self._get_target_class()( + client, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + default_operation_timeout=expected_operation_timeout, + default_attempt_timeout=expected_attempt_timeout, + default_read_rows_operation_timeout=expected_read_rows_operation_timeout, + default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout, + default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, + default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, + ) + asyncio.sleep(0) + assert table.table_id == expected_table_id + assert table.instance_id == expected_instance_id + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + assert table.default_operation_timeout == expected_operation_timeout + assert table.default_attempt_timeout == expected_attempt_timeout + assert ( + table.default_read_rows_operation_timeout + == expected_read_rows_operation_timeout + ) + assert ( + table.default_read_rows_attempt_timeout + == expected_read_rows_attempt_timeout + ) + assert ( + table.default_mutate_rows_operation_timeout + == expected_mutate_rows_operation_timeout + ) + assert ( + table.default_mutate_rows_attempt_timeout + == expected_mutate_rows_attempt_timeout + ) + 
table._register_instance_future + assert table._register_instance_future.done() + assert not table._register_instance_future.cancelled() + assert table._register_instance_future.exception() is None + client.close() + + def test_table_ctor_defaults(self): + """should provide default timeout values and app_profile_id""" + from google.cloud.bigtable.data._async.client import TableAsync + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + client = self._make_client() + assert not client._active_instances + table = TableAsync(client, expected_instance_id, expected_table_id) + asyncio.sleep(0) + assert table.table_id == expected_table_id + assert table.instance_id == expected_instance_id + assert table.app_profile_id is None + assert table.client is client + assert table.default_operation_timeout == 60 + assert table.default_read_rows_operation_timeout == 600 + assert table.default_mutate_rows_operation_timeout == 600 + assert table.default_attempt_timeout == 20 + assert table.default_read_rows_attempt_timeout == 20 + assert table.default_mutate_rows_attempt_timeout == 60 + client.close() + + def test_table_ctor_invalid_timeout_values(self): + """bad timeout values should raise ValueError""" + from google.cloud.bigtable.data._async.client import TableAsync + + client = self._make_client() + timeout_pairs = [ + ("default_operation_timeout", "default_attempt_timeout"), + ( + "default_read_rows_operation_timeout", + "default_read_rows_attempt_timeout", + ), + ( + "default_mutate_rows_operation_timeout", + "default_mutate_rows_attempt_timeout", + ), + ] + for operation_timeout, attempt_timeout in timeout_pairs: + with pytest.raises(ValueError) as e: + TableAsync(client, "", "", **{attempt_timeout: -1}) + assert "attempt_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + TableAsync(client, "", "", **{operation_timeout: -1}) + assert "operation_timeout must be greater than 0" in str(e.value) + client.close() + + def test_table_ctor_sync(self): + from google.cloud.bigtable.data._async.client import TableAsync + + client = mock.Mock() + with pytest.raises(RuntimeError) as e: + TableAsync(client, "instance-id", "table-id") + assert e.match("TableAsync must be created within an async event loop context.") + + @pytest.mark.parametrize( + "fn_name,fn_args,is_stream,extra_retryables", + [ + ("read_rows_stream", (ReadRowsQuery(),), True, ()), + ("read_rows", (ReadRowsQuery(),), True, ()), + ("read_row", (b"row_key",), True, ()), + ("read_rows_sharded", ([ReadRowsQuery()],), True, ()), + ("row_exists", (b"row_key",), True, ()), + ("sample_row_keys", (), False, ()), + ("mutate_row", (b"row_key", [mock.Mock()]), False, ()), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + False, + (_MutateRowsIncomplete,), + ), + ], + ) + @pytest.mark.parametrize( + "input_retryables,expected_retryables", + [ + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), + ], + ) + def test_customizable_retryable_errors( + self, + input_retryables, + expected_retryables, + fn_name, + fn_args, + is_stream, + extra_retryables, + ): + """ + Test that retryable 
functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer. + """ + retry_fn = "retry_target" + if is_stream: + retry_fn += "_stream" + if self.is_async: + retry_fn = f"CrossSync.{retry_fn}" + else: + retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" + with mock.patch( + f"google.cloud.bigtable.data._sync.cross_sync.{retry_fn}" + ) as retry_fn_mock: + with self._make_client() as client: + table = client.get_table("instance-id", "table-id") + expected_predicate = lambda a: a in expected_retryables + retry_fn_mock.side_effect = RuntimeError("stop early") + with mock.patch( + "google.api_core.retry.if_exception_type" + ) as predicate_builder_mock: + predicate_builder_mock.return_value = expected_predicate + with pytest.raises(Exception): + test_fn = table.__getattribute__(fn_name) + test_fn(*fn_args, retryable_errors=input_retryables) + predicate_builder_mock.assert_called_once_with( + *expected_retryables, *extra_retryables + ) + retry_call_args = retry_fn_mock.call_args_list[0].args + assert retry_call_args[1] is expected_predicate + + @pytest.mark.parametrize( + "fn_name,fn_args,gapic_fn", + [ + ("read_rows_stream", (ReadRowsQuery(),), "read_rows"), + ("read_rows", (ReadRowsQuery(),), "read_rows"), + ("read_row", (b"row_key",), "read_rows"), + ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), + ("row_exists", (b"row_key",), "read_rows"), + ("sample_row_keys", (), "sample_row_keys"), + ("mutate_row", (b"row_key", [mock.Mock()]), "mutate_row"), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + "mutate_rows", + ), + ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), + ( + "read_modify_write_row", + (b"row_key", mock.Mock()), + "read_modify_write_row", + ), + ], + ) + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): + """check that all requests attach proper metadata headers""" + from google.cloud.bigtable.data import TableAsync + + profile = "profile" if include_app_profile else None + with mock.patch.object( + BigtableAsyncClient, gapic_fn, mock.AsyncMock() + ) as gapic_mock: + gapic_mock.side_effect = RuntimeError("stop early") + with self._make_client() as client: + table = TableAsync(client, "instance-id", "table-id", profile) + try: + test_fn = table.__getattribute__(fn_name) + maybe_stream = test_fn(*fn_args) + [i for i in maybe_stream] + except Exception: + pass + kwargs = gapic_mock.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == "x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py new file mode 100644 index 000000000..6d942f48c --- /dev/null +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -0,0 +1,1125 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by sync_surface_generator.py. Do not edit. + + +from mock import AsyncMock +from tests.unit.data._async.test_mutations_batcher import Test_FlowControl +from unittest import mock +from unittest.mock import AsyncMock +import asyncio +import mock +import pytest +import time + +from google.cloud.bigtable.data import TABLE_DEFAULT +from google.cloud.bigtable.data import TableAsync +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +import google.api_core.exceptions +import google.api_core.retry + + +class TestMutationsBatcher: + def _get_target_class(self): + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) + + return MutationsBatcherAsync + + @staticmethod + def is_async(): + return True + + def _make_one(self, table=None, **kwargs): + from google.api_core.exceptions import DeadlineExceeded + from google.api_core.exceptions import ServiceUnavailable + + if table is None: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 10 + table.default_mutate_rows_retryable_errors = ( + DeadlineExceeded, + ServiceUnavailable, + ) + return self._get_target_class()(table, **kwargs) + + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def test_ctor_defaults(self): + with mock.patch.object( + self._get_target_class(), "_timer_routine", return_value=asyncio.Future() + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = [Exception] + with self._make_one(table) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._max_mutation_count == 100000 + assert instance._flow_control._max_mutation_bytes == 104857600 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert ( + instance._operation_timeout + == table.default_mutate_rows_operation_timeout + ) + assert ( + instance._attempt_timeout + == table.default_mutate_rows_attempt_timeout + ) + assert ( + instance._retryable_errors + == table.default_mutate_rows_retryable_errors + ) + asyncio.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == 5 + assert isinstance(instance._flush_timer, asyncio.Future) + + def test_ctor_explicit(self): + """Test with explicit parameters""" + with mock.patch.object( + self._get_target_class(), "_timer_routine", return_value=asyncio.Future() + ) as flush_timer_mock: + table = 
mock.Mock() + flush_interval = 20 + flush_limit_count = 17 + flush_limit_bytes = 19 + flow_control_max_mutation_count = 1001 + flow_control_max_bytes = 12 + operation_timeout = 11 + attempt_timeout = 2 + retryable_errors = [Exception] + with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=operation_timeout, + batch_attempt_timeout=attempt_timeout, + batch_retryable_errors=retryable_errors, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert ( + instance._flow_control._max_mutation_count + == flow_control_max_mutation_count + ) + assert ( + instance._flow_control._max_mutation_bytes == flow_control_max_bytes + ) + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert instance._operation_timeout == operation_timeout + assert instance._attempt_timeout == attempt_timeout + assert instance._retryable_errors == retryable_errors + asyncio.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == flush_interval + assert isinstance(instance._flush_timer, asyncio.Future) + + def test_ctor_no_flush_limits(self): + """Test with None for flush limits""" + with mock.patch.object( + self._get_target_class(), "_timer_routine", return_value=asyncio.Future() + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = () + flush_interval = None + flush_limit_count = None + flush_limit_bytes = None + with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._staged_entries == [] + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + asyncio.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] is None + assert isinstance(instance._flush_timer, asyncio.Future) + + def test_ctor_invalid_values(self): + """Test that timeout values are positive, and fit within expected limits""" + with pytest.raises(ValueError) as e: + self._make_one(batch_operation_timeout=-1) + assert "operation_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._make_one(batch_attempt_timeout=-1) + assert "attempt_timeout must be greater than 0" in str(e.value) + + def test_default_argument_consistency(self): + """ + We supply default arguments in MutationsBatcherAsync.__init__, and in + 
table.mutations_batcher. Make sure any changes to defaults are applied to + both places + """ + import inspect + + get_batcher_signature = dict( + inspect.signature(TableAsync.mutations_batcher).parameters + ) + get_batcher_signature.pop("self") + batcher_init_signature = dict( + inspect.signature(self._get_target_class()).parameters + ) + batcher_init_signature.pop("table") + assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) + assert len(get_batcher_signature) == 8 + assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) + for arg_name in get_batcher_signature.keys(): + assert ( + get_batcher_signature[arg_name].default + == batcher_init_signature[arg_name].default + ) + + @pytest.mark.parametrize("input_val", [None, 0, -1]) + def test__start_flush_timer_w_empty_input(self, input_val): + """Empty/invalid timer should return immediately""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + with self._make_one() as instance: + if self.is_async(): + sleep_obj, sleep_method = (asyncio, "wait_for") + else: + sleep_obj, sleep_method = (instance._closed, "wait") + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + result = instance._timer_routine(input_val) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + assert result is None + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test__start_flush_timer_call_when_closed(self): + """closed batcher's timer should return immediately""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + with self._make_one() as instance: + instance.close() + flush_mock.reset_mock() + if self.is_async(): + sleep_obj, sleep_method = (asyncio, "wait_for") + else: + sleep_obj, sleep_method = (instance._closed, "wait") + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + instance._timer_routine(10) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + + @pytest.mark.parametrize("num_staged", [0, 1, 10]) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test__flush_timer(self, num_staged): + """Timer should continue to call _schedule_flush in a loop""" + from google.cloud.bigtable.data._sync.cross_sync import CrossSync + + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + expected_sleep = 12 + with self._make_one(flush_interval=expected_sleep) as instance: + loop_num = 3 + instance._staged_entries = [mock.Mock()] * num_staged + with mock.patch.object( + CrossSync._Sync_Impl, "event_wait" + ) as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] + with pytest.raises(TabError): + self._get_target_class()._timer_routine( + instance, expected_sleep + ) + instance._flush_timer = asyncio.Future() + assert sleep_mock.call_count == loop_num + 1 + sleep_kwargs = sleep_mock.call_args[1] + assert sleep_kwargs["timeout"] == expected_sleep + assert flush_mock.call_count == (0 if num_staged == 0 else loop_num) + + def test__flush_timer_close(self): + """Timer should continue terminate after close""" + with mock.patch.object(self._get_target_class(), "_schedule_flush"): + with self._make_one() as instance: + with mock.patch("asyncio.sleep"): + asyncio.sleep(0.5) + assert instance._flush_timer.done() is False + instance.close() + asyncio.sleep(0.1) + assert instance._flush_timer.done() is True + + def test_append_closed(self): + """Should raise exception""" + instance = self._make_one() + 
instance.close() + with pytest.raises(RuntimeError): + instance.append(mock.Mock()) + + def test_append_wrong_mutation(self): + """ + Mutation objects should raise an exception. + Only support RowMutationEntry + """ + from google.cloud.bigtable.data.mutations import DeleteAllFromRow + + with self._make_one() as instance: + expected_error = "invalid mutation type: DeleteAllFromRow. Only RowMutationEntry objects are supported by batcher" + with pytest.raises(ValueError) as e: + instance.append(DeleteAllFromRow()) + assert str(e.value) == expected_error + + def test_append_outside_flow_limits(self): + """entries larger than mutation limits are still processed""" + with self._make_one( + flow_control_max_mutation_count=1, flow_control_max_bytes=1 + ) as instance: + oversized_entry = self._make_mutation(count=0, size=2) + instance.append(oversized_entry) + assert instance._staged_entries == [oversized_entry] + assert instance._staged_count == 0 + assert instance._staged_bytes == 2 + instance._staged_entries = [] + with self._make_one( + flow_control_max_mutation_count=1, flow_control_max_bytes=1 + ) as instance: + overcount_entry = self._make_mutation(count=2, size=0) + instance.append(overcount_entry) + assert instance._staged_entries == [overcount_entry] + assert instance._staged_count == 2 + assert instance._staged_bytes == 0 + instance._staged_entries = [] + + def test_append_flush_runs_after_limit_hit(self): + """ + If the user appends a bunch of entries above the flush limits back-to-back, + it should still flush in a single task + """ + with mock.patch.object( + self._get_target_class(), "_execute_mutate_rows" + ) as op_mock: + with self._make_one(flush_limit_bytes=100) as instance: + + def mock_call(*args, **kwargs): + return [] + + op_mock.side_effect = mock_call + instance.append(self._make_mutation(size=99)) + num_entries = 10 + for _ in range(num_entries): + instance.append(self._make_mutation(size=1)) + instance._wait_for_batch_results(*instance._flush_jobs) + assert op_mock.call_count == 1 + sent_batch = op_mock.call_args[0][0] + assert len(sent_batch) == 2 + assert len(instance._staged_entries) == num_entries - 1 + + @pytest.mark.parametrize( + "flush_count,flush_bytes,mutation_count,mutation_bytes,expect_flush", + [ + (10, 10, 1, 1, False), + (10, 10, 9, 9, False), + (10, 10, 10, 1, True), + (10, 10, 1, 10, True), + (10, 10, 10, 10, True), + (1, 1, 10, 10, True), + (1, 1, 0, 0, False), + ], + ) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_append( + self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush + ): + """test appending different mutations, and checking if it causes a flush""" + with self._make_one( + flush_limit_mutation_count=flush_count, flush_limit_bytes=flush_bytes + ) as instance: + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert instance._staged_entries == [] + mutation = self._make_mutation(count=mutation_count, size=mutation_bytes) + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + instance.append(mutation) + assert flush_mock.call_count == bool(expect_flush) + assert instance._staged_count == mutation_count + assert instance._staged_bytes == mutation_bytes + assert instance._staged_entries == [mutation] + instance._staged_entries = [] + + def test_append_multiple_sequentially(self): + """Append multiple mutations""" + with self._make_one( + flush_limit_mutation_count=8, flush_limit_bytes=8 + ) as instance: + assert instance._staged_count == 0 + assert 
instance._staged_bytes == 0 + assert instance._staged_entries == [] + mutation = self._make_mutation(count=2, size=3) + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + instance.append(mutation) + assert flush_mock.call_count == 0 + assert instance._staged_count == 2 + assert instance._staged_bytes == 3 + assert len(instance._staged_entries) == 1 + instance.append(mutation) + assert flush_mock.call_count == 0 + assert instance._staged_count == 4 + assert instance._staged_bytes == 6 + assert len(instance._staged_entries) == 2 + instance.append(mutation) + assert flush_mock.call_count == 1 + assert instance._staged_count == 6 + assert instance._staged_bytes == 9 + assert len(instance._staged_entries) == 3 + instance._staged_entries = [] + + def test_flush_flow_control_concurrent_requests(self): + """requests should happen in parallel if flow control breaks up single flush into batches""" + import time + + num_calls = 10 + fake_mutations = [self._make_mutation(count=1) for _ in range(num_calls)] + with self._make_one(flow_control_max_mutation_count=1) as instance: + with mock.patch.object( + instance, "_execute_mutate_rows", AsyncMock() + ) as op_mock: + + def mock_call(*args, **kwargs): + asyncio.sleep(0.1) + return [] + + op_mock.side_effect = mock_call + start_time = time.monotonic() + instance._staged_entries = fake_mutations + instance._schedule_flush() + asyncio.sleep(0.01) + for i in range(num_calls): + instance._flow_control.remove_from_flow( + [self._make_mutation(count=1)] + ) + asyncio.sleep(0.01) + instance._wait_for_batch_results(*instance._flush_jobs) + duration = time.monotonic() - start_time + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert duration < 0.5 + assert op_mock.call_count == num_calls + + def test_schedule_flush_no_mutations(self): + """schedule flush should return None if no staged mutations""" + with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal") as flush_mock: + for i in range(3): + assert instance._schedule_flush() is None + assert flush_mock.call_count == 0 + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_schedule_flush_with_mutations(self): + """if new mutations exist, should add a new flush task to _flush_jobs""" + with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal") as flush_mock: + if not self.is_async(): + flush_mock.side_effect = lambda x: time.sleep(0.1) + for i in range(1, 4): + mutation = mock.Mock() + instance._staged_entries = [mutation] + instance._schedule_flush() + assert instance._staged_entries == [] + asyncio.sleep(0) + assert instance._staged_entries == [] + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert flush_mock.call_count == 1 + flush_mock.reset_mock() + + def test__flush_internal(self): + """ + _flush_internal should: + - await previous flush call + - delegate batching to _flow_control + - call _execute_mutate_rows on each batch + - update self.exceptions and self._entries_processed_since_last_raise + """ + num_entries = 10 + with self._make_one() as instance: + with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: + with mock.patch.object( + instance._flow_control, "add_to_flow" + ) as flow_mock: + + def gen(x): + yield x + + flow_mock.side_effect = lambda x: gen(x) + mutations = [self._make_mutation(count=1, size=1)] * num_entries + instance._flush_internal(mutations) + assert instance._entries_processed_since_last_raise 
== num_entries + assert execute_mock.call_count == 1 + assert flow_mock.call_count == 1 + instance._oldest_exceptions.clear() + instance._newest_exceptions.clear() + + def test_flush_clears_job_list(self): + """ + a job should be added to _flush_jobs when _schedule_flush is called, + and removed when it completes + """ + with self._make_one() as instance: + with mock.patch.object( + instance, "_flush_internal", AsyncMock() + ) as flush_mock: + if not self.is_async(): + flush_mock.side_effect = lambda x: time.sleep(0.1) + mutations = [self._make_mutation(count=1, size=1)] + instance._staged_entries = mutations + assert instance._flush_jobs == set() + new_job = instance._schedule_flush() + assert instance._flush_jobs == {new_job} + if self.is_async(): + new_job + else: + new_job.result() + assert instance._flush_jobs == set() + + @pytest.mark.parametrize( + "num_starting,num_new_errors,expected_total_errors", + [ + (0, 0, 0), + (0, 1, 1), + (0, 2, 2), + (1, 0, 1), + (1, 1, 2), + (10, 2, 12), + (10, 20, 20), + ], + ) + def test__flush_internal_with_errors( + self, num_starting, num_new_errors, expected_total_errors + ): + """errors returned from _execute_mutate_rows should be added to internal exceptions""" + from google.cloud.bigtable.data import exceptions + + num_entries = 10 + expected_errors = [ + exceptions.FailedMutationEntryError(mock.Mock(), mock.Mock(), ValueError()) + ] * num_new_errors + with self._make_one() as instance: + instance._oldest_exceptions = [mock.Mock()] * num_starting + with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: + execute_mock.return_value = expected_errors + with mock.patch.object( + instance._flow_control, "add_to_flow" + ) as flow_mock: + + def gen(x): + yield x + + flow_mock.side_effect = lambda x: gen(x) + mutations = [self._make_mutation(count=1, size=1)] * num_entries + instance._flush_internal(mutations) + assert instance._entries_processed_since_last_raise == num_entries + assert execute_mock.call_count == 1 + assert flow_mock.call_count == 1 + found_exceptions = instance._oldest_exceptions + list( + instance._newest_exceptions + ) + assert len(found_exceptions) == expected_total_errors + for i in range(num_starting, expected_total_errors): + assert found_exceptions[i] == expected_errors[i - num_starting] + assert found_exceptions[i].index is None + instance._oldest_exceptions.clear() + instance._newest_exceptions.clear() + + def _mock_gapic_return(self, num=5): + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc import status_pb2 + + def gen(num): + for i in range(num): + entry = MutateRowsResponse.Entry( + index=i, status=status_pb2.Status(code=0) + ) + yield MutateRowsResponse(entries=[entry]) + + return gen(num) + + def test_timer_flush_end_to_end(self): + """Flush should automatically trigger after flush_interval""" + num_nutations = 10 + mutations = [self._make_mutation(count=2, size=2)] * num_nutations + with self._make_one(flush_interval=0.05) as instance: + instance._table.default_operation_timeout = 10 + instance._table.default_attempt_timeout = 9 + with mock.patch.object( + instance._table.client._gapic_client, "mutate_rows" + ) as gapic_mock: + gapic_mock.side_effect = ( + lambda *args, **kwargs: self._mock_gapic_return(num_nutations) + ) + for m in mutations: + instance.append(m) + assert instance._entries_processed_since_last_raise == 0 + asyncio.sleep(0.1) + assert instance._entries_processed_since_last_raise == num_nutations + + def test__execute_mutate_rows(self): + if 
self.is_async():
+            mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync"
+        else:
+            mutate_path = "_sync._mutate_rows._MutateRowsOperation"
+        with mock.patch(f"google.cloud.bigtable.data.{mutate_path}") as mutate_rows:
+            mutate_rows.return_value = AsyncMock()
+            start_operation = mutate_rows().start
+            table = mock.Mock()
+            table.table_name = "test-table"
+            table.app_profile_id = "test-app-profile"
+            table.default_mutate_rows_operation_timeout = 17
+            table.default_mutate_rows_attempt_timeout = 13
+            table.default_mutate_rows_retryable_errors = ()
+            with self._make_one(table) as instance:
+                batch = [self._make_mutation()]
+                result = instance._execute_mutate_rows(batch)
+                assert start_operation.call_count == 1
+                args, kwargs = mutate_rows.call_args
+                assert args[0] == table.client._gapic_client
+                assert args[1] == table
+                assert args[2] == batch
+                assert kwargs["operation_timeout"] == 17
+                assert kwargs["attempt_timeout"] == 13
+                assert result == []
+
+    def test__execute_mutate_rows_returns_errors(self):
+        """Errors from the operation should be returned as a list"""
+        from google.cloud.bigtable.data.exceptions import (
+            MutationsExceptionGroup,
+            FailedMutationEntryError,
+        )
+
+        if self.is_async():
+            mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync"
+        else:
+            mutate_path = "_sync._mutate_rows._MutateRowsOperation"
+        with mock.patch(
+            f"google.cloud.bigtable.data.{mutate_path}.start"
+        ) as mutate_rows:
+            err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error"))
+            err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error"))
+            mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10)
+            table = mock.Mock()
+            table.default_mutate_rows_operation_timeout = 17
+            table.default_mutate_rows_attempt_timeout = 13
+            table.default_mutate_rows_retryable_errors = ()
+            with self._make_one(table) as instance:
+                batch = [self._make_mutation()]
+                result = instance._execute_mutate_rows(batch)
+                assert len(result) == 2
+                assert result[0] == err1
+                assert result[1] == err2
+                assert result[0].index is None
+                assert result[1].index is None
+
+    def test__raise_exceptions(self):
+        """Raise exceptions and reset error state"""
+        from google.cloud.bigtable.data import exceptions
+
+        expected_total = 1201
+        expected_exceptions = [RuntimeError("mock")] * 3
+        with self._make_one() as instance:
+            instance._oldest_exceptions = expected_exceptions
+            instance._entries_processed_since_last_raise = expected_total
+            try:
+                instance._raise_exceptions()
+            except exceptions.MutationsExceptionGroup as exc:
+                assert list(exc.exceptions) == expected_exceptions
+                assert str(expected_total) in str(exc)
+            assert instance._entries_processed_since_last_raise == 0
+            instance._oldest_exceptions, instance._newest_exceptions = ([], [])
+            instance._raise_exceptions()
+
+    def test___aenter__(self):
+        """Should return self"""
+        with self._make_one() as instance:
+            assert instance.__aenter__() == instance
+
+    def test___aexit__(self):
+        """aexit should call close"""
+        with self._make_one() as instance:
+            with mock.patch.object(instance, "close") as close_mock:
+                instance.__aexit__(None, None, None)
+                assert close_mock.call_count == 1
+
+    def test_close(self):
+        """Should clean up all resources"""
+        with self._make_one() as instance:
+            with mock.patch.object(instance, "_schedule_flush") as flush_mock:
+                with mock.patch.object(instance, "_raise_exceptions") as raise_mock:
+                    instance.close()
+                    assert instance.closed is True
+                    assert instance._flush_timer.done() is True
+                    assert instance._flush_jobs ==
set() + assert flush_mock.call_count == 1 + assert raise_mock.call_count == 1 + + def test_close_w_exceptions(self): + """Raise exceptions on close""" + from google.cloud.bigtable.data import exceptions + + expected_total = 10 + expected_exceptions = [RuntimeError("mock")] + with self._make_one() as instance: + instance._oldest_exceptions = expected_exceptions + instance._entries_processed_since_last_raise = expected_total + try: + instance.close() + except exceptions.MutationsExceptionGroup as exc: + assert list(exc.exceptions) == expected_exceptions + assert str(expected_total) in str(exc) + assert instance._entries_processed_since_last_raise == 0 + instance._oldest_exceptions, instance._newest_exceptions = ([], []) + + def test__on_exit(self, recwarn): + """Should raise warnings if unflushed mutations exist""" + with self._make_one() as instance: + instance._on_exit() + assert len(recwarn) == 0 + num_left = 4 + instance._staged_entries = [mock.Mock()] * num_left + with pytest.warns(UserWarning) as w: + instance._on_exit() + assert len(w) == 1 + assert "unflushed mutations" in str(w[0].message).lower() + assert str(num_left) in str(w[0].message) + instance._closed.set() + instance._on_exit() + assert len(recwarn) == 0 + instance._staged_entries = [] + + def test_atexit_registration(self): + """Should run _on_exit on program termination""" + import atexit + + with mock.patch.object(atexit, "register") as register_mock: + assert register_mock.call_count == 0 + with self._make_one(): + assert register_mock.call_count == 1 + + def test_timeout_args_passed(self): + """ + batch_operation_timeout and batch_attempt_timeout should be used + in api calls + """ + if self.is_async(): + mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" + else: + mutate_path = "_sync._mutate_rows._MutateRowsOperation" + with mock.patch( + f"google.cloud.bigtable.data.{mutate_path}", return_value=AsyncMock() + ) as mutate_rows: + expected_operation_timeout = 17 + expected_attempt_timeout = 13 + with self._make_one( + batch_operation_timeout=expected_operation_timeout, + batch_attempt_timeout=expected_attempt_timeout, + ) as instance: + assert instance._operation_timeout == expected_operation_timeout + assert instance._attempt_timeout == expected_attempt_timeout + instance._execute_mutate_rows([self._make_mutation()]) + assert mutate_rows.call_count == 1 + kwargs = mutate_rows.call_args[1] + assert kwargs["operation_timeout"] == expected_operation_timeout + assert kwargs["attempt_timeout"] == expected_attempt_timeout + + @pytest.mark.parametrize( + "limit,in_e,start_e,end_e", + [ + (10, 0, (10, 0), (10, 0)), + (1, 10, (0, 0), (1, 1)), + (10, 1, (0, 0), (1, 0)), + (10, 10, (0, 0), (10, 0)), + (10, 11, (0, 0), (10, 1)), + (3, 20, (0, 0), (3, 3)), + (10, 20, (0, 0), (10, 10)), + (10, 21, (0, 0), (10, 10)), + (2, 1, (2, 0), (2, 1)), + (2, 1, (1, 0), (2, 0)), + (2, 2, (1, 0), (2, 1)), + (3, 1, (3, 1), (3, 2)), + (3, 3, (3, 1), (3, 3)), + (1000, 5, (999, 0), (1000, 4)), + (1000, 5, (0, 0), (5, 0)), + (1000, 5, (1000, 0), (1000, 5)), + ], + ) + def test__add_exceptions(self, limit, in_e, start_e, end_e): + """ + Test that the _add_exceptions function properly updates the + _oldest_exceptions and _newest_exceptions lists + Args: + - limit: the _exception_list_limit representing the max size of either list + - in_e: size of list of exceptions to send to _add_exceptions + - start_e: a tuple of ints representing the initial sizes of _oldest_exceptions and _newest_exceptions + - end_e: a tuple of ints representing 
the expected sizes of _oldest_exceptions and _newest_exceptions + """ + from collections import deque + + input_list = [RuntimeError(f"mock {i}") for i in range(in_e)] + mock_batcher = mock.Mock() + mock_batcher._oldest_exceptions = [ + RuntimeError(f"starting mock {i}") for i in range(start_e[0]) + ] + mock_batcher._newest_exceptions = deque( + [RuntimeError(f"starting mock {i}") for i in range(start_e[1])], + maxlen=limit, + ) + mock_batcher._exception_list_limit = limit + mock_batcher._exceptions_since_last_raise = 0 + self._get_target_class()._add_exceptions(mock_batcher, input_list) + assert len(mock_batcher._oldest_exceptions) == end_e[0] + assert len(mock_batcher._newest_exceptions) == end_e[1] + assert mock_batcher._exceptions_since_last_raise == in_e + oldest_list_diff = end_e[0] - start_e[0] + newest_list_diff = min(max(in_e - oldest_list_diff, 0), limit) + for i in range(oldest_list_diff): + assert mock_batcher._oldest_exceptions[i + start_e[0]] == input_list[i] + for i in range(1, newest_list_diff + 1): + assert mock_batcher._newest_exceptions[-i] == input_list[-i] + + @pytest.mark.parametrize( + "input_retryables,expected_retryables", + [ + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), + ], + ) + def test_customizable_retryable_errors(self, input_retryables, expected_retryables): + """ + Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer. 
+ """ + retryn_fn = ( + "google.cloud.bigtable.data._sync.cross_sync.CrossSync.retry_target" + if "Async" in self._get_target_class().__name__ + else "google.api_core.retry.retry_target" + ) + with mock.patch.object( + google.api_core.retry, "if_exception_type" + ) as predicate_builder_mock: + with mock.patch(retryn_fn) as retry_fn_mock: + table = None + with mock.patch("asyncio.create_task"): + table = TableAsync(mock.Mock(), "instance", "table") + with self._make_one( + table, batch_retryable_errors=input_retryables + ) as instance: + assert instance._retryable_errors == expected_retryables + expected_predicate = lambda a: a in expected_retryables + predicate_builder_mock.return_value = expected_predicate + retry_fn_mock.side_effect = RuntimeError("stop early") + mutation = self._make_mutation(count=1, size=1) + instance._execute_mutate_rows([mutation]) + predicate_builder_mock.assert_called_once_with( + *expected_retryables, _MutateRowsIncomplete + ) + retry_call_args = retry_fn_mock.call_args_list[0].args + assert retry_call_args[1] is expected_predicate + + +class Test_FlowControl: + @staticmethod + def _target_class(): + from google.cloud.bigtable.data._async.mutations_batcher import ( + _FlowControlAsync, + ) + + return _FlowControlAsync + + def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): + return self._target_class()(max_mutation_count, max_mutation_bytes) + + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def test_ctor(self): + max_mutation_count = 9 + max_mutation_bytes = 19 + instance = self._make_one(max_mutation_count, max_mutation_bytes) + assert instance._max_mutation_count == max_mutation_count + assert instance._max_mutation_bytes == max_mutation_bytes + assert instance._in_flight_mutation_count == 0 + assert instance._in_flight_mutation_bytes == 0 + assert isinstance(instance._capacity_condition, asyncio.Condition) + + def test_ctor_invalid_values(self): + """Test that values are positive, and fit within expected limits""" + with pytest.raises(ValueError) as e: + self._make_one(0, 1) + assert "max_mutation_count must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._make_one(1, 0) + assert "max_mutation_bytes must be greater than 0" in str(e.value) + + @pytest.mark.parametrize( + "max_count,max_size,existing_count,existing_size,new_count,new_size,expected", + [ + (1, 1, 0, 0, 0, 0, True), + (1, 1, 1, 1, 1, 1, False), + (10, 10, 0, 0, 0, 0, True), + (10, 10, 0, 0, 9, 9, True), + (10, 10, 0, 0, 11, 9, True), + (10, 10, 0, 1, 11, 9, True), + (10, 10, 1, 0, 11, 9, False), + (10, 10, 0, 0, 9, 11, True), + (10, 10, 1, 0, 9, 11, True), + (10, 10, 0, 1, 9, 11, False), + (10, 1, 0, 0, 1, 0, True), + (1, 10, 0, 0, 0, 8, True), + (float("inf"), float("inf"), 0, 0, 10000000000.0, 10000000000.0, True), + (8, 8, 0, 0, 10000000000.0, 10000000000.0, True), + (12, 12, 6, 6, 5, 5, True), + (12, 12, 5, 5, 6, 6, True), + (12, 12, 6, 6, 6, 6, True), + (12, 12, 6, 6, 7, 7, False), + (12, 12, 0, 0, 13, 13, True), + (12, 12, 12, 0, 0, 13, True), + (12, 12, 0, 12, 13, 0, True), + (12, 12, 1, 1, 13, 13, False), + (12, 12, 1, 1, 0, 13, False), + (12, 12, 1, 1, 13, 0, False), + ], + ) + def test__has_capacity( + self, + max_count, + max_size, + existing_count, + existing_size, + new_count, + new_size, + expected, + ): + """_has_capacity should return True if the new mutation will will not exceed the max count or 
size""" + instance = self._make_one(max_count, max_size) + instance._in_flight_mutation_count = existing_count + instance._in_flight_mutation_bytes = existing_size + assert instance._has_capacity(new_count, new_size) == expected + + @pytest.mark.parametrize( + "existing_count,existing_size,added_count,added_size,new_count,new_size", + [ + (0, 0, 0, 0, 0, 0), + (2, 2, 1, 1, 1, 1), + (2, 0, 1, 0, 1, 0), + (0, 2, 0, 1, 0, 1), + (10, 10, 0, 0, 10, 10), + (10, 10, 5, 5, 5, 5), + (0, 0, 1, 1, -1, -1), + ], + ) + def test_remove_from_flow_value_update( + self, + existing_count, + existing_size, + added_count, + added_size, + new_count, + new_size, + ): + """completed mutations should lower the inflight values""" + instance = self._make_one() + instance._in_flight_mutation_count = existing_count + instance._in_flight_mutation_bytes = existing_size + mutation = self._make_mutation(added_count, added_size) + instance.remove_from_flow(mutation) + assert instance._in_flight_mutation_count == new_count + assert instance._in_flight_mutation_bytes == new_size + + def test__remove_from_flow_unlock(self): + """capacity condition should notify after mutation is complete""" + import inspect + + instance = self._make_one(10, 10) + instance._in_flight_mutation_count = 10 + instance._in_flight_mutation_bytes = 10 + + def task_routine(): + with instance._capacity_condition: + instance._capacity_condition.wait_for( + lambda: instance._has_capacity(1, 1) + ) + + if inspect.iscoroutinefunction(task_routine): + task = asyncio.create_task(task_routine()) + task_alive = lambda: not task.done() + else: + import threading + + thread = threading.Thread(target=task_routine) + thread.start() + task_alive = thread.is_alive + asyncio.sleep(0.05) + assert task_alive() is True + mutation = self._make_mutation(count=0, size=5) + instance.remove_from_flow([mutation]) + asyncio.sleep(0.05) + assert instance._in_flight_mutation_count == 10 + assert instance._in_flight_mutation_bytes == 5 + assert task_alive() is True + instance._in_flight_mutation_bytes = 10 + mutation = self._make_mutation(count=5, size=0) + instance.remove_from_flow([mutation]) + asyncio.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 10 + assert task_alive() is True + instance._in_flight_mutation_count = 10 + mutation = self._make_mutation(count=5, size=5) + instance.remove_from_flow([mutation]) + asyncio.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 5 + assert task_alive() is False + + @pytest.mark.parametrize( + "mutations,count_cap,size_cap,expected_results", + [ + ([(5, 5), (1, 1), (1, 1)], 10, 10, [[(5, 5), (1, 1), (1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 1, 1, [[(1, 1)], [(1, 1)], [(1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 2, 10, [[(1, 1), (1, 1)], [(1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 10, 2, [[(1, 1), (1, 1)], [(1, 1)]]), + ( + [(1, 1), (5, 5), (4, 1), (1, 4), (1, 1)], + 5, + 5, + [[(1, 1)], [(5, 5)], [(4, 1), (1, 4)], [(1, 1)]], + ), + ], + ) + def test_add_to_flow(self, mutations, count_cap, size_cap, expected_results): + """Test batching with various flow control settings""" + mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] + instance = self._make_one(count_cap, size_cap) + i = 0 + for batch in instance.add_to_flow(mutation_objs): + expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + assert len(batch[j].mutations) == 
expected_batch[j][0] + assert batch[j].size() == expected_batch[j][1] + instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + @pytest.mark.parametrize( + "mutations,max_limit,expected_results", + [ + ([(1, 1)] * 11, 10, [[(1, 1)] * 10, [(1, 1)]]), + ([(1, 1)] * 10, 1, [[(1, 1)] for _ in range(10)]), + ([(1, 1)] * 10, 2, [[(1, 1), (1, 1)] for _ in range(5)]), + ], + ) + def test_add_to_flow_max_mutation_limits( + self, mutations, max_limit, expected_results + ): + """ + Test flow control running up against the max API limit + Should submit request early, even if the flow control has room for more + """ + async_patch = mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + max_limit, + ) + sync_patch = mock.patch( + "google.cloud.bigtable.data.mutations._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + max_limit, + ) + with async_patch, sync_patch: + mutation_objs = [ + self._make_mutation(count=m[0], size=m[1]) for m in mutations + ] + instance = self._make_one(float("inf"), float("inf")) + i = 0 + for batch in instance.add_to_flow(mutation_objs): + expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + assert len(batch[j].mutations) == expected_batch[j][0] + assert batch[j].size() == expected_batch[j][1] + instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + def test_add_to_flow_oversize(self): + """mutations over the flow control limits should still be accepted""" + instance = self._make_one(2, 3) + large_size_mutation = self._make_mutation(count=1, size=10) + large_count_mutation = self._make_mutation(count=10, size=1) + results = [out for out in instance.add_to_flow([large_size_mutation])] + assert len(results) == 1 + instance.remove_from_flow(results[0]) + count_results = [out for out in instance.add_to_flow(large_count_mutation)] + assert len(count_results) == 1 diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync/test_read_rows_acceptance.py new file mode 100644 index 000000000..9d6df6f9c --- /dev/null +++ b/tests/unit/data/_sync/test_read_rows_acceptance.py @@ -0,0 +1,328 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by sync_surface_generator.py. Do not edit. 
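For context, the parametrized add_to_flow cases in the batcher tests above all encode one rule: entries are grouped greedily, a batch is cut before an entry that would push the running totals past count_cap or size_cap, and an entry that is oversized on its own still goes out as a single-element batch. A self-contained sketch of that rule follows; the batch_mutations helper and its (count, size) tuple representation are illustrative only, not the library's flow-control implementation.

def batch_mutations(entries, count_cap, size_cap):
    # entries: (mutation_count, size_in_bytes) tuples, as produced by _make_mutation above
    batches, current, count, size = [], [], 0, 0
    for entry_count, entry_size in entries:
        # cut the current batch before an entry that would exceed either cap
        if current and (count + entry_count > count_cap or size + entry_size > size_cap):
            batches.append(current)
            current, count, size = [], 0, 0
        current.append((entry_count, entry_size))
        count += entry_count
        size += entry_size
    if current:
        batches.append(current)
    return batches

# mirrors the (count_cap=2, size_cap=10) case in the parametrized table above
assert batch_mutations([(1, 1), (1, 1), (1, 1)], 2, 10) == [[(1, 1), (1, 1)], [(1, 1)]]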
+ + +from __future__ import annotations +from itertools import zip_longest +from tests.unit.v2_client.test_row_merger import ReadRowsTest +from tests.unit.v2_client.test_row_merger import TestFile +import mock +import os +import pytest +import warnings + +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.row import Row +from google.cloud.bigtable_v2 import ReadRowsResponse + + +class TestReadRowsAcceptance: + @staticmethod + def _get_operation_class(): + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + + return _ReadRowsOperationAsync + + @staticmethod + def _get_client_class(): + from google.cloud.bigtable.data._async.client import BigtableDataClientAsync + + return BigtableDataClientAsync + + def parse_readrows_acceptance_tests(): + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "../read-rows-acceptance-test.json") + with open(filename) as json_file: + test_json = TestFile.from_json(json_file.read()) + return test_json.read_rows_tests + + @staticmethod + def extract_results_from_row(row: Row): + results = [] + for family, col, cells in row.items(): + for cell in cells: + results.append( + ReadRowsTest.Result( + row_key=row.row_key, + family_name=family, + qualifier=col, + timestamp_micros=cell.timestamp_ns // 1000, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + ) + return results + + @staticmethod + def _coro_wrapper(stream): + return stream + + def _process_chunks(self, *chunks): + def _row_stream(): + yield ReadRowsResponse(chunks=chunks) + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + results = [] + for row in merger: + results.append(row) + return results + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + def test_row_merger_scenario(self, test_case: ReadRowsTest): + def _scenerio_stream(): + for chunk in test_case.chunks: + yield ReadRowsResponse(chunks=[chunk]) + + try: + results = [] + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_scenerio_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + for row in merger: + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + def test_read_rows_scenario(self, test_case: ReadRowsTest): + def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list): + self.chunk_list = chunk_list + self.idx = -1 + + def __aiter__(self): + return self + + def __anext__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + chunk = self.chunk_list[self.idx] + return 
ReadRowsResponse(chunks=[chunk]) + raise StopAsyncIteration + + def cancel(self): + pass + + return mock_stream(chunk_list) + + with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + client = self._get_client_class()() + try: + table = client.get_table("instance", "table") + results = [] + with mock.patch.object( + table.client._gapic_client, "read_rows" + ) as read_rows: + read_rows.return_value = _make_gapic_stream(test_case.chunks) + for row in table.read_rows_stream(query={}): + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + finally: + client.close() + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + def test_out_of_order_rows(self): + def _row_stream(): + yield ReadRowsResponse(last_scanned_row_key=b"a") + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = b"b" + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + with pytest.raises(InvalidChunk): + for _ in merger: + pass + + def test_bare_reset(self): + first_chunk = ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk( + row_key=b"a", family_name="f", qualifier=b"q", value=b"v" + ) + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, family_name="f") + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, value=b"v") + ), + ) + + def test_missing_family(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + qualifier=b"q", + timestamp_micros=1000, + value=b"v", + commit_row=True, + ) + ) + + def test_mid_cell_row_key_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), + ) + + def test_mid_cell_family_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + 
family_name="f2", value=b"v", commit_row=True + ), + ) + + def test_mid_cell_qualifier_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + qualifier=b"q2", value=b"v", commit_row=True + ), + ) + + def test_mid_cell_timestamp_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + timestamp_micros=2000, value=b"v", commit_row=True + ), + ) + + def test_mid_cell_labels_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), + ) From 6ae2428b18f43877a9a08a5bdeebd7d4e9cb0eb8 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 24 Jun 2024 14:48:02 -0700 Subject: [PATCH 099/360] improved import generation --- .../bigtable/data/_async/_mutate_rows.py | 2 +- .../cloud/bigtable/data/_sync/_mutate_rows.py | 16 ++++--- .../cloud/bigtable/data/_sync/_read_rows.py | 6 +-- google/cloud/bigtable/data/_sync/client.py | 40 ++++++++--------- .../bigtable/data/_sync/mutations_batcher.py | 14 ++++-- sync_surface_generator.py | 45 ++++--------------- tests/unit/data/_async/test_client.py | 21 ++++++--- tests/unit/data/_sync/test__mutate_rows.py | 12 ++--- tests/unit/data/_sync/test__read_rows.py | 8 ++-- tests/unit/data/_sync/test_client.py | 32 +++++++------ .../unit/data/_sync/test_mutations_batcher.py | 14 +++--- .../data/_sync/test_read_rows_acceptance.py | 3 +- 12 files changed, 102 insertions(+), 111 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 8f1e64f61..439b2629c 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -47,7 +47,7 @@ from google.cloud.bigtable_v2.services.bigtable.client import ( # noqa: F401 BigtableClient, ) - + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto @dataclass class _EntryWithProto: diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index 37849eca4..810f29661 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -16,24 +16,30 @@ from __future__ import annotations -from typing import Sequence +from typing import Sequence, TYPE_CHECKING import functools from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries -from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory -from google.cloud.bigtable.data._sync.client import Table from google.cloud.bigtable.data._sync.cross_sync import CrossSync from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.exceptions import 
RetryExceptionGroup from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete -from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT -from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient + +if TYPE_CHECKING: + from google.cloud.bigtable.data.mutations import RowMutationEntry + + if CrossSync.is_async: + pass + else: + from google.cloud.bigtable.data._sync.client import Table + from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto class _MutateRowsOperation: diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index 9e702e1d6..f46b80f4c 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -16,20 +16,16 @@ from __future__ import annotations -from typing import Iterable from typing import Sequence from google.api_core import retry as retries from google.api_core.retry import exponential_sleep_generator from google.cloud.bigtable.data import _helpers -from google.cloud.bigtable.data._async._read_rows import _ResetRow -from google.cloud.bigtable.data._sync.client import Table from google.cloud.bigtable.data._sync.cross_sync import CrossSync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _RowSetComplete from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery -from google.cloud.bigtable.data.row import Cell -from google.cloud.bigtable.data.row import Row +from google.cloud.bigtable.data.row import Row, Cell from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB from google.cloud.bigtable_v2.types import RowRange as RowRangePB diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index bf2193fd4..260ef1e22 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -18,12 +18,7 @@ from __future__ import annotations from functools import partial from grpc import Channel -from typing import Any -from typing import Iterable -from typing import Optional -from typing import Sequence -from typing import Set -from typing import cast +from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING import asyncio import concurrent.futures import os @@ -38,18 +33,12 @@ from google.api_core.exceptions import ServiceUnavailable from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT from google.cloud.bigtable.data import _helpers -from google.cloud.bigtable.data._helpers import RowKeySamples -from google.cloud.bigtable.data._helpers import ShardedQuery from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import _MB_SIZE -from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation -from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation from google.cloud.bigtable.data._sync.cross_sync import CrossSync -from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher from google.cloud.bigtable.data.exceptions import FailedQueryShardError from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup -from google.cloud.bigtable.data.mutations 
import Mutation -from google.cloud.bigtable.data.mutations import RowMutationEntry +from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.row import Row @@ -57,20 +46,31 @@ from google.cloud.bigtable.data.row_filters import RowFilter from google.cloud.bigtable.data.row_filters import RowFilterChain from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter -from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta from google.cloud.bigtable_v2.services.bigtable.transports.base import ( DEFAULT_CLIENT_INFO, ) -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledBigtableGrpcTransport, -) -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledChannel, -) from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest from google.cloud.client import ClientWithProject from google.cloud.environment_vars import BIGTABLE_EMULATOR + +if CrossSync.is_async: + pass +else: + from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation + from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledBigtableGrpcTransport, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledChannel, + ) + from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient + from typing import Iterable +if TYPE_CHECKING: + from google.cloud.bigtable.data._helpers import RowKeySamples + from google.cloud.bigtable.data._helpers import ShardedQuery import google.auth._default import google.auth.credentials diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index f3644e62b..c750b8b78 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -17,7 +17,7 @@ from __future__ import annotations from collections import deque -from typing import Sequence +from typing import Sequence, TYPE_CHECKING import atexit import concurrent.futures import warnings @@ -26,15 +26,21 @@ from google.cloud.bigtable.data._helpers import _MB_SIZE from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts -from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation -from google.cloud.bigtable.data._sync.client import Table from google.cloud.bigtable.data._sync.cross_sync import CrossSync from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.mutations import Mutation -from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +if TYPE_CHECKING: + from google.cloud.bigtable.data.mutations import RowMutationEntry + + if CrossSync.is_async: + pass + else: + from google.cloud.bigtable.data._sync.client import Table + from google.cloud.bigtable.data._sync._mutate_rows 
import _MutateRowsOperation + class MutationsBatcher: """ diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 25ba85a0d..d8fbc87fa 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -259,35 +259,17 @@ def _create_error_node(node, error_msg): def get_imports(self, filename): """ - Get the imports from a file, and do a find-and-replace against asyncio_replacements + Extract all imports from file root + + Include if statements that contain imports """ - imports = set() with open(filename, "r") as f: full_tree = ast.parse(f.read(), filename) - for node in ast.walk(full_tree): - if isinstance(node, (ast.Import, ast.ImportFrom)): - for alias in node.names: - if isinstance(node, ast.Import): - # import statments - new_import = self.asyncio_replacements.get(alias.name, alias.name) - imports.add(ast.parse(f"import {new_import}").body[0]) - else: - # import from statements - # break into individual components - full_path = f"{node.module}.{alias.name}" - if full_path in self.asyncio_replacements: - full_path = self.asyncio_replacements[full_path] - module, name = full_path.rsplit(".", 1) - # don't import from same file - if module == ".": - continue - asname_str = f" as {alias.asname}" if alias.asname else "" - imports.add( - ast.parse(f"from {module} import {name}{asname_str}").body[ - 0 - ] - ) - return imports + imports = [node for node in full_tree.body if isinstance(node, (ast.Import, ast.ImportFrom))] + if_imports = [node for node in full_tree.body if isinstance(node, ast.If) and any(isinstance(n, (ast.Import, ast.ImportFrom)) for n in node.body)] + try_imports = [node for node in full_tree.body if isinstance(node, ast.Try)] + return set(imports + if_imports + try_imports) + def transform_class(in_obj: Type, **kwargs): filename = inspect.getfile(in_obj) @@ -316,17 +298,6 @@ def transform_class(in_obj: Type, **kwargs): transformer.visit(ast_tree) # find imports imports = transformer.get_imports(filename) - # imports.add(ast.parse("from abc import ABC").body[0]) - # add locals from file, in case they are needed - if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): - with open(filename, "r") as f: - for node in ast.walk(ast.parse(f.read(), filename)): - if isinstance(node, ast.ClassDef): - imports.add( - ast.parse( - f"from {in_obj.__module__} import {node.name}" - ).body[0] - ) return ast_tree.body, imports diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 8af98c376..66bba4b97 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -59,9 +59,14 @@ class TestBigtableDataClientAsync: @staticmethod def _get_target_class(): - from google.cloud.bigtable.data._async.client import BigtableDataClientAsync + if CrossSync.is_async: + from google.cloud.bigtable.data._async.client import BigtableDataClientAsync - return BigtableDataClientAsync + return BigtableDataClientAsync + else: + from google.cloud.bigtable.data._sync.client import BigtableDataClient + + return BigtableDataClient @classmethod def _make_client(cls, *args, use_emulator=True, **kwargs): @@ -337,8 +342,6 @@ async def test__ping_and_warm_instances(self): """ test ping and warm with mocked asyncio.gather """ - from google.cloud.bigtable.data._sync.cross_sync import CrossSync - client_mock = mock.Mock() client_mock._execute_ping_and_warms = ( lambda *args: self._get_target_class()._execute_ping_and_warms( @@ -1049,8 +1052,6 @@ async def test_close(self): @pytest.mark.asyncio async def 
test_close_with_timeout(self): - from google.cloud.bigtable.data._sync.cross_sync import CrossSync - pool_size = 7 expected_timeout = 19 client = self._make_client(project="project-id", pool_size=pool_size) @@ -1097,6 +1098,7 @@ def test_client_ctor_sync(self): @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestTable", + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, ) class TestTableAsync: def _make_client(self, *args, **kwargs): @@ -1418,6 +1420,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestReadRows", + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, ) class TestReadRowsAsync: """ @@ -1927,6 +1930,7 @@ async def test_row_exists(self, return_value, expected_result): @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestReadRowsSharded", + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, ) class TestReadRowsShardedAsync: def _make_client(self, *args, **kwargs): @@ -2150,6 +2154,7 @@ async def mock_call(*args, **kwargs): @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestSampleRowKeys", + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, ) class TestSampleRowKeysAsync: def _make_client(self, *args, **kwargs): @@ -2304,6 +2309,7 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestMutateRow", + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, ) class TestMutateRowAsync: def _make_client(self, *args, **kwargs): @@ -2482,6 +2488,7 @@ async def test_mutate_row_no_mutations(self, mutations): @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestBulkMutateRows", + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, ) class TestBulkMutateRowsAsync: def _make_client(self, *args, **kwargs): @@ -2864,6 +2871,7 @@ async def test_bulk_mutate_error_recovery(self): @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestCheckAndMutateRow", + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, ) class TestCheckAndMutateRowAsync: def _make_client(self, *args, **kwargs): @@ -3018,6 +3026,7 @@ async def test_check_and_mutate_mutations_parsing(self): @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestReadModifyWriteRow", + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, ) class TestReadModifyWriteRowAsync: def _make_client(self, *args, **kwargs): diff --git a/tests/unit/data/_sync/test__mutate_rows.py b/tests/unit/data/_sync/test__mutate_rows.py index bbe6fbb22..f4fc2d279 100644 --- a/tests/unit/data/_sync/test__mutate_rows.py +++ b/tests/unit/data/_sync/test__mutate_rows.py @@ -15,13 +15,15 @@ # This file is automatically generated by sync_surface_generator.py. Do not edit. 
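The hunk below swaps the generated tests' flat mock imports for a guarded form. The motivation, stated here for context (the pattern itself is taken from the diff that follows): AsyncMock is only bundled with unittest.mock on Python 3.8 and later, so older interpreters need the standalone mock package to supply it.

try:
    from unittest import mock
    from unittest.mock import AsyncMock  # noqa: F401
except ImportError:  # older Python: fall back to the external "mock" package
    import mock  # type: ignore
    from mock import AsyncMock  # noqa: F401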
-from mock import AsyncMock -# from tests.unit.data._async.test__mutate_rows import TestMutateRowsOperation -from unittest import mock -from unittest.mock import AsyncMock -import mock import pytest +try: + from unittest import mock + from unittest.mock import AsyncMock +except ImportError: + import mock + from mock import AsyncMock + from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import Forbidden from google.cloud.bigtable.data._sync.cross_sync import CrossSync diff --git a/tests/unit/data/_sync/test__read_rows.py b/tests/unit/data/_sync/test__read_rows.py index 0cda06432..78b61cebc 100644 --- a/tests/unit/data/_sync/test__read_rows.py +++ b/tests/unit/data/_sync/test__read_rows.py @@ -15,11 +15,13 @@ # This file is automatically generated by sync_surface_generator.py. Do not edit. -from tests.unit.data._async.test__read_rows import TestReadRowsOperation -from unittest import mock -import mock import pytest +try: + from unittest import mock +except ImportError: + import mock + class TestReadRowsOperation: """ diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 4feafd373..b05bc94bf 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -16,30 +16,30 @@ from __future__ import annotations -from mock import AsyncMock -from tests.unit.data._async.test_client import TestBigtableDataClientAsync -from tests.unit.data._async.test_client import TestReadRowsAsync -from tests.unit.data._async.test_client import TestTableAsync -from unittest import mock -from unittest.mock import AsyncMock import asyncio import grpc -import mock import pytest import re import sys +try: + from unittest import mock + from unittest.mock import AsyncMock +except ImportError: + import mock + from mock import AsyncMock + from google.api_core import exceptions as core_exceptions from google.api_core import grpc_helpers_async from google.auth.credentials import AnonymousCredentials from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data import mutations +from google.cloud.bigtable.data._sync.cross_sync import CrossSync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery -from google.cloud.bigtable_v2 import ReadRowsResponse from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( PooledBigtableGrpcAsyncIOTransport, @@ -53,11 +53,13 @@ class TestBigtableDataClient: @staticmethod def _get_target_class(): - if CrossSync.is_async: + if CrossSync._Sync_Impl.is_async: from google.cloud.bigtable.data._async.client import BigtableDataClientAsync + return BigtableDataClientAsync else: from google.cloud.bigtable.data._sync.client import BigtableDataClient + return BigtableDataClient @classmethod @@ -305,8 +307,6 @@ def test__start_background_channel_refresh_tasks_names(self): def test__ping_and_warm_instances(self): """test ping and warm with mocked asyncio.gather""" - from google.cloud.bigtable.data._sync.cross_sync import CrossSync - client_mock = mock.Mock() client_mock._execute_ping_and_warms = ( lambda *args: self._get_target_class()._execute_ping_and_warms( @@ 
-933,8 +933,6 @@ def test_close(self): assert client._channel_refresh_tasks == [] def test_close_with_timeout(self): - from google.cloud.bigtable.data._sync.cross_sync import CrossSync - pool_size = 7 expected_timeout = 19 client = self._make_client(project="project-id", pool_size=pool_size) @@ -1739,7 +1737,7 @@ def _get_operation_class(): return _ReadRowsOperationAsync def _make_client(self, *args, **kwargs): - return TestBigtableDataClientAsync._make_client(*args, **kwargs) + return TestBigtableDataClient._make_client(*args, **kwargs) def _make_table(self, *args, **kwargs): from google.cloud.bigtable.data._async.client import TableAsync @@ -2195,7 +2193,7 @@ def test_row_exists(self, return_value, expected_result): class TestReadRowsSharded: def _make_client(self, *args, **kwargs): - return TestBigtableDataClientAsync._make_client(*args, **kwargs) + return TestBigtableDataClient._make_client(*args, **kwargs) def test_read_rows_sharded_empty_query(self): with self._make_client() as client: @@ -2388,7 +2386,7 @@ def mock_call(*args, **kwargs): class TestSampleRowKeys: def _make_client(self, *args, **kwargs): - return TestBigtableDataClientAsync._make_client(*args, **kwargs) + return TestBigtableDataClient._make_client(*args, **kwargs) def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): from google.cloud.bigtable_v2.types import SampleRowKeysResponse @@ -2516,7 +2514,7 @@ def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): class TestTable: def _make_client(self, *args, **kwargs): - return TestBigtableDataClientAsync._make_client(*args, **kwargs) + return TestBigtableDataClient._make_client(*args, **kwargs) @staticmethod def _get_target_class(): diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py index 6d942f48c..afdf9f905 100644 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -15,19 +15,21 @@ # This file is automatically generated by sync_surface_generator.py. Do not edit. 
-from mock import AsyncMock -from tests.unit.data._async.test_mutations_batcher import Test_FlowControl -from unittest import mock -from unittest.mock import AsyncMock import asyncio -import mock import pytest import time +try: + from unittest import mock + from unittest.mock import AsyncMock +except ImportError: + import mock + from mock import AsyncMock + from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data import TableAsync from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete -import google.api_core.exceptions +import google.api_core.exceptions as core_exceptions import google.api_core.retry diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync/test_read_rows_acceptance.py index 9d6df6f9c..1f423232c 100644 --- a/tests/unit/data/_sync/test_read_rows_acceptance.py +++ b/tests/unit/data/_sync/test_read_rows_acceptance.py @@ -17,8 +17,7 @@ from __future__ import annotations from itertools import zip_longest -from tests.unit.v2_client.test_row_merger import ReadRowsTest -from tests.unit.v2_client.test_row_merger import TestFile +from tests.unit.v2_client.test_row_merger import ReadRowsTest, TestFile import mock import os import pytest From fdce0bccad5fffd59cac6dbf6d4ae257dc85b2b1 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 24 Jun 2024 15:52:26 -0700 Subject: [PATCH 100/360] fixed cross sync import conditional --- .../cloud/bigtable/data/_sync/_mutate_rows.py | 2 +- google/cloud/bigtable/data/_sync/client.py | 2 +- .../bigtable/data/_sync/mutations_batcher.py | 2 +- sync_surface_generator.py | 4 +-- tests/unit/data/_async/test_client.py | 27 ++++++++++++------- tests/unit/data/_sync/test_client.py | 20 +++++++++----- 6 files changed, 36 insertions(+), 21 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index 810f29661..7e50b803b 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -34,7 +34,7 @@ if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry - if CrossSync.is_async: + if CrossSync._Sync_Impl.is_async: pass else: from google.cloud.bigtable.data._sync.client import Table diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 260ef1e22..5ef6fd325 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -54,7 +54,7 @@ from google.cloud.client import ClientWithProject from google.cloud.environment_vars import BIGTABLE_EMULATOR -if CrossSync.is_async: +if CrossSync._Sync_Impl.is_async: pass else: from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index c750b8b78..8fa12022a 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry - if CrossSync.is_async: + if CrossSync._Sync_Impl.is_async: pass else: from google.cloud.bigtable.data._sync.client import Table diff --git a/sync_surface_generator.py b/sync_surface_generator.py index d8fbc87fa..224d81bf4 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -266,8 +266,8 @@ def get_imports(self, filename): with open(filename, "r") as f: 
full_tree = ast.parse(f.read(), filename) imports = [node for node in full_tree.body if isinstance(node, (ast.Import, ast.ImportFrom))] - if_imports = [node for node in full_tree.body if isinstance(node, ast.If) and any(isinstance(n, (ast.Import, ast.ImportFrom)) for n in node.body)] - try_imports = [node for node in full_tree.body if isinstance(node, ast.Try)] + if_imports = [self.visit(node) for node in full_tree.body if isinstance(node, ast.If) and any(isinstance(n, (ast.Import, ast.ImportFrom)) for n in node.body)] + try_imports = [self.visit(node) for node in full_tree.body if isinstance(node, ast.Try)] return set(imports + if_imports + try_imports) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 66bba4b97..0ea9d5dc7 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -24,15 +24,6 @@ from google.cloud.bigtable.data import mutations from google.auth.credentials import AnonymousCredentials from google.cloud.bigtable_v2.types import ReadRowsResponse -from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, -) -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, -) -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel as PooledChannelAsync, -) from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.api_core import exceptions as core_exceptions from google.cloud.bigtable.data.exceptions import InvalidChunk @@ -52,6 +43,24 @@ import mock # type: ignore from mock import AsyncMock # type: ignore +if CrossSync.is_async: + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel as PooledChannelAsync, + ) +else: + from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledBigtableGrpcTransport, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledChannel, + ) @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestBigtableDataClient", diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index b05bc94bf..9eb849b18 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -40,15 +40,21 @@ from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery -from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, -) -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel as PooledChannelAsync, -) from google.cloud.bigtable_v2.types import ReadRowsResponse +if CrossSync._Sync_Impl.is_async: + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio 
import ( + PooledBigtableGrpcAsyncIOTransport, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel as PooledChannelAsync, + ) +else: + pass + class TestBigtableDataClient: @staticmethod From adc8bb7064ed161429136b63579fa87dbf382430 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 24 Jun 2024 15:55:30 -0700 Subject: [PATCH 101/360] fixed import --- google/cloud/bigtable/data/_async/_mutate_rows.py | 4 +++- google/cloud/bigtable/data/_sync/_mutate_rows.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 439b2629c..dc2c81052 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -34,6 +34,9 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if not CrossSync.is_async: + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto + if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -47,7 +50,6 @@ from google.cloud.bigtable_v2.services.bigtable.client import ( # noqa: F401 BigtableClient, ) - from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto @dataclass class _EntryWithProto: diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index 7e50b803b..738df0668 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -39,7 +39,8 @@ else: from google.cloud.bigtable.data._sync.client import Table from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient - from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto +if not CrossSync._Sync_Impl.is_async: + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto class _MutateRowsOperation: From 31fb77a2e203ca3ceaf518b1cf68f9a20f5b55f7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 24 Jun 2024 16:09:19 -0700 Subject: [PATCH 102/360] got rpc tests passing --- tests/unit/data/_async/test_client.py | 24 ++++++----- tests/unit/data/_sync/test_client.py | 57 +++++++++++---------------- 2 files changed, 33 insertions(+), 48 deletions(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 0ea9d5dc7..a3d18131b 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -53,6 +53,8 @@ from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( PooledChannel as PooledChannelAsync, ) + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + from google.cloud.bigtable.data._async.client import TableAsync else: from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( @@ -61,6 +63,8 @@ from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( PooledChannel, ) + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + from google.cloud.bigtable.data._sync.client import Table @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestBigtableDataClient", @@ -1107,7 +1111,7 @@ def test_client_ctor_sync(self): @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestTable", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, + 
replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "TableAsync": "Table"}, ) class TestTableAsync: def _make_client(self, *args, **kwargs): @@ -1115,8 +1119,6 @@ def _make_client(self, *args, **kwargs): @staticmethod def _get_target_class(): - from google.cloud.bigtable.data._async.client import TableAsync - return TableAsync @property @@ -1429,7 +1431,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestReadRows", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "__aiter__": "__iter__", "__anext__": "__next__", "StopAsyncIteration": "StopIteration", "_ReadRowsOperationAsync": "_ReadRowsOperation", "TestTableAsync": "TestTable"}, ) class TestReadRowsAsync: """ @@ -1438,16 +1440,12 @@ class TestReadRowsAsync: @staticmethod def _get_operation_class(): - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - return _ReadRowsOperationAsync def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) def _make_table(self, *args, **kwargs): - from google.cloud.bigtable.data._async.client import TableAsync - client_mock = mock.Mock() client_mock._register_instance.side_effect = ( lambda *args, **kwargs: asyncio.sleep(0) @@ -1463,7 +1461,7 @@ def _make_table(self, *args, **kwargs): ) client_mock._gapic_client.table_path.return_value = kwargs["table_id"] client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] - return TableAsync(client_mock, *args, **kwargs) + return TestTableAsync._get_target_class()(client_mock, *args, **kwargs) def _make_stats(self): from google.cloud.bigtable_v2.types import RequestStats @@ -1513,7 +1511,7 @@ async def __anext__(self): self.idx += 1 if len(self.chunk_list) > self.idx: if sleep_time: - await asyncio.sleep(self.sleep_time) + await CrossSync.sleep(self.sleep_time) chunk = self.chunk_list[self.idx] if isinstance(chunk, Exception): raise chunk @@ -1939,7 +1937,7 @@ async def test_row_exists(self, return_value, expected_result): @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestReadRowsSharded", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "TestReadRowsAsync": "TestReadRows"}, ) class TestReadRowsShardedAsync: def _make_client(self, *args, **kwargs): @@ -2143,7 +2141,7 @@ async def test_read_rows_sharded_negative_batch_timeout(self): from google.api_core.exceptions import DeadlineExceeded async def mock_call(*args, **kwargs): - await asyncio.sleep(0.05) + await CrossSync.sleep(0.05) return [mock.Mock()] async with self._make_client() as client: @@ -2163,7 +2161,7 @@ async def mock_call(*args, **kwargs): @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestSampleRowKeys", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "AsyncMock": "mock.Mock"}, ) class TestSampleRowKeysAsync: def _make_client(self, *args, **kwargs): diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 9eb849b18..c2a9271dc 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -53,7 +53,8 @@ PooledChannel as PooledChannelAsync, ) else: - pass + from 
google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + from google.cloud.bigtable.data._sync.client import Table class TestBigtableDataClient: @@ -1738,16 +1739,12 @@ class TestReadRows: @staticmethod def _get_operation_class(): - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - - return _ReadRowsOperationAsync + return _ReadRowsOperation def _make_client(self, *args, **kwargs): return TestBigtableDataClient._make_client(*args, **kwargs) def _make_table(self, *args, **kwargs): - from google.cloud.bigtable.data._async.client import TableAsync - client_mock = mock.Mock() client_mock._register_instance.side_effect = ( lambda *args, **kwargs: asyncio.sleep(0) @@ -1763,7 +1760,7 @@ def _make_table(self, *args, **kwargs): ) client_mock._gapic_client.table_path.return_value = kwargs["table_id"] client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] - return TableAsync(client_mock, *args, **kwargs) + return TestTable._get_target_class()(client_mock, *args, **kwargs) def _make_stats(self): from google.cloud.bigtable_v2.types import RequestStats @@ -1804,20 +1801,20 @@ def __init__(self, chunk_list, sleep_time): self.idx = -1 self.sleep_time = sleep_time - def __aiter__(self): + def __iter__(self): return self - def __anext__(self): + def __next__(self): self.idx += 1 if len(self.chunk_list) > self.idx: if sleep_time: - asyncio.sleep(self.sleep_time) + CrossSync._Sync_Impl.sleep(self.sleep_time) chunk = self.chunk_list[self.idx] if isinstance(chunk, Exception): raise chunk else: return ReadRowsResponse(chunks=[chunk]) - raise StopAsyncIteration + raise StopIteration def cancel(self): pass @@ -2216,9 +2213,9 @@ def test_read_rows_sharded_multiple_queries(self): table.client._gapic_client, "read_rows" ) as read_rows: read_rows.side_effect = ( - lambda *args, **kwargs: TestReadRowsAsync._make_gapic_stream( + lambda *args, **kwargs: TestReadRows._make_gapic_stream( [ - TestReadRowsAsync._make_chunk(row_key=k) + TestReadRows._make_chunk(row_key=k) for k in args[0].rows.row_keys ] ) @@ -2370,7 +2367,7 @@ def test_read_rows_sharded_negative_batch_timeout(self): from google.api_core.exceptions import DeadlineExceeded def mock_call(*args, **kwargs): - asyncio.sleep(0.05) + CrossSync._Sync_Impl.sleep(0.05) return [mock.Mock()] with self._make_client() as client: @@ -2406,7 +2403,7 @@ def test_sample_row_keys(self): with self._make_client() as client: with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", mock.Mock() ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream(samples) result = table.sample_row_keys() @@ -2440,7 +2437,7 @@ def test_sample_row_keys_default_timeout(self): default_attempt_timeout=expected_timeout, ) as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", mock.Mock() ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) result = table.sample_row_keys() @@ -2460,7 +2457,7 @@ def test_sample_row_keys_gapic_params(self): instance, table_id, app_profile_id=expected_profile ) as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", mock.Mock() ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) 
table.sample_row_keys(attempt_timeout=expected_timeout) @@ -2485,7 +2482,7 @@ def test_sample_row_keys_retryable_errors(self, retryable_exception): with self._make_client() as client: with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", mock.Mock() ) as sample_row_keys: sample_row_keys.side_effect = retryable_exception("mock") with pytest.raises(DeadlineExceeded) as e: @@ -2511,7 +2508,7 @@ def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): with self._make_client() as client: with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", mock.Mock() ) as sample_row_keys: sample_row_keys.side_effect = non_retryable_exception("mock") with pytest.raises(non_retryable_exception): @@ -2524,9 +2521,7 @@ def _make_client(self, *args, **kwargs): @staticmethod def _get_target_class(): - from google.cloud.bigtable.data._async.client import TableAsync - - return TableAsync + return Table @property def is_async(self): @@ -2594,13 +2589,11 @@ def test_table_ctor(self): def test_table_ctor_defaults(self): """should provide default timeout values and app_profile_id""" - from google.cloud.bigtable.data._async.client import TableAsync - expected_table_id = "table-id" expected_instance_id = "instance-id" client = self._make_client() assert not client._active_instances - table = TableAsync(client, expected_instance_id, expected_table_id) + table = Table(client, expected_instance_id, expected_table_id) asyncio.sleep(0) assert table.table_id == expected_table_id assert table.instance_id == expected_instance_id @@ -2616,8 +2609,6 @@ def test_table_ctor_defaults(self): def test_table_ctor_invalid_timeout_values(self): """bad timeout values should raise ValueError""" - from google.cloud.bigtable.data._async.client import TableAsync - client = self._make_client() timeout_pairs = [ ("default_operation_timeout", "default_attempt_timeout"), @@ -2632,19 +2623,17 @@ def test_table_ctor_invalid_timeout_values(self): ] for operation_timeout, attempt_timeout in timeout_pairs: with pytest.raises(ValueError) as e: - TableAsync(client, "", "", **{attempt_timeout: -1}) + Table(client, "", "", **{attempt_timeout: -1}) assert "attempt_timeout must be greater than 0" in str(e.value) with pytest.raises(ValueError) as e: - TableAsync(client, "", "", **{operation_timeout: -1}) + Table(client, "", "", **{operation_timeout: -1}) assert "operation_timeout must be greater than 0" in str(e.value) client.close() def test_table_ctor_sync(self): - from google.cloud.bigtable.data._async.client import TableAsync - client = mock.Mock() with pytest.raises(RuntimeError) as e: - TableAsync(client, "instance-id", "table-id") + Table(client, "instance-id", "table-id") assert e.match("TableAsync must be created within an async event loop context.") @pytest.mark.parametrize( @@ -2754,15 +2743,13 @@ def test_customizable_retryable_errors( @pytest.mark.parametrize("include_app_profile", [True, False]) def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): """check that all requests attach proper metadata headers""" - from google.cloud.bigtable.data import TableAsync - profile = "profile" if include_app_profile else None with mock.patch.object( BigtableAsyncClient, gapic_fn, mock.AsyncMock() ) as gapic_mock: gapic_mock.side_effect = 
RuntimeError("stop early") with self._make_client() as client: - table = TableAsync(client, "instance-id", "table-id", profile) + table = Table(client, "instance-id", "table-id", profile) try: test_fn = table.__getattribute__(fn_name) maybe_stream = test_fn(*fn_args) From fc44b307cf24dba0ced65c290911fab9b1a3ef0b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 24 Jun 2024 16:17:43 -0700 Subject: [PATCH 103/360] removed custom is_async --- tests/unit/data/_async/test_client.py | 42 ++++++++----------- tests/unit/data/_sync/test_client.py | 59 +++++++++++++++------------ 2 files changed, 51 insertions(+), 50 deletions(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index a3d18131b..0e9e1fa21 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -100,10 +100,6 @@ def _make_client(cls, *args, use_emulator=True, **kwargs): with mock.patch.dict(os.environ, env_mask): return cls._get_target_class()(*args, **kwargs) - @property - def is_async(self): - return True - @pytest.mark.asyncio async def test_ctor(self): expected_project = "project-id" @@ -134,7 +130,7 @@ async def test_ctor_super_inits(self): credentials = AnonymousCredentials() client_options = {"api_endpoint": "foo.bar:1234"} options_parsed = client_options_lib.from_dict(client_options) - asyncio_portion = "-async" if self.is_async else "" + asyncio_portion = "-async" if CrossSync.is_async else "" transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: bigtable_client_init.return_value = None @@ -191,7 +187,7 @@ async def test_ctor_dict_options(self): @pytest.mark.asyncio async def test_veneer_grpc_headers(self): - client_component = "data-async" if self.is_async else "data" + client_component = "data-async" if CrossSync.is_async else "data" VENEER_HEADER_REGEX = re.compile( r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-" + client_component @@ -200,7 +196,7 @@ async def test_veneer_grpc_headers(self): # client_info should be populated with headers to # detect as a veneer client - if self.is_async: + if CrossSync_async: patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") else: patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") @@ -264,7 +260,7 @@ async def test_channel_pool_rotation(self): async def test_channel_pool_replace(self): import time - sleep_module = asyncio if self.is_async else time + sleep_module = asyncio if CrossSync.is_async else time with mock.patch.object(sleep_module, "sleep"): pool_size = 7 client = self._make_client(project="project-id", pool_size=pool_size) @@ -281,7 +277,7 @@ async def test_channel_pool_replace(self): replace_idx, grace=grace_period, new_channel=new_channel ) close.assert_called_once() - if self.is_async: + if CrossSync.is_async: close.assert_called_once_with(grace=grace_period) close.assert_awaited_once() assert client.transport._grpc_channel._pool[replace_idx] == new_channel @@ -324,7 +320,7 @@ async def test__start_background_channel_refresh(self, pool_size): client._start_background_channel_refresh() assert len(client._channel_refresh_tasks) == pool_size for task in client._channel_refresh_tasks: - if self.is_async: + if CrossSync.is_async: assert isinstance(task, asyncio.Task) else: assert isinstance(task, concurrent.futures.Future) @@ -387,7 +383,7 @@ async def test__ping_and_warm_instances(self): # expect one partial for each instance 
partial_list = gather.call_args.args[0] assert len(partial_list) == 4 - if self.is_async: + if CrossSync.is_async: gather.assert_awaited_once() # check grpc call arguments grpc_call_args = channel.unary_unary().call_args_list @@ -420,11 +416,11 @@ async def test__ping_and_warm_single_instance(self): ) ) gather_tuple = ( - (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") + (asyncio, "gather") if CrossSync.is_async else (client_mock._executor, "submit") ) with mock.patch.object(*gather_tuple, AsyncMock()) as gather: gather.side_effect = lambda *args, **kwargs: [mock.Mock() for _ in args] - if self.is_async: + if CrossSync.is_async: # simulate gather by returning the same number of items as passed in # gather is expected to return None for each coroutine passed gather.side_effect = lambda *args, **kwargs: [None for _ in args] @@ -476,7 +472,7 @@ async def test__manage_channel_first_sleep( with mock.patch.object(time, "monotonic") as monotonic: monotonic.return_value = 0 sleep_tuple = ( - (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") ) with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = asyncio.CancelledError @@ -509,7 +505,7 @@ async def test__manage_channel_ping_and_warm(self): new_channel = mock.Mock() client_mock.transport.grpc_channel._create_channel.return_value = new_channel # should ping an warm all new channels, and old channels if sleeping - sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + sleep_tuple = (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") with mock.patch.object(*sleep_tuple): # stop process after replace_channel is called client_mock.transport.replace_channel.side_effect = asyncio.CancelledError @@ -563,7 +559,7 @@ async def test__manage_channel_sleeps( with mock.patch.object(time, "time") as time_mock: time_mock.return_value = 0 sleep_tuple = ( - (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") ) with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = [None for i in range(num_cycles - 1)] + [ @@ -592,7 +588,7 @@ async def test__manage_channel_random(self): import random import threading - sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + sleep_tuple = (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") with mock.patch.object(*sleep_tuple) as sleep: with mock.patch.object(random, "uniform") as uniform: uniform.return_value = 0 @@ -627,14 +623,14 @@ async def test__manage_channel_refresh(self, num_cycles): expected_grace = 9 expected_refresh = 0.5 channel_idx = 1 - grpc_lib = grpc.aio if self.is_async else grpc + grpc_lib = grpc.aio if CrossSync.is_async else grpc new_channel = grpc_lib.insecure_channel("localhost:8080") with mock.patch.object( PooledBigtableGrpcAsyncIOTransport, "replace_channel" ) as replace_channel: sleep_tuple = ( - (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") ) with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = [None for i in range(num_cycles)] + [ @@ -1111,7 +1107,7 @@ def test_client_ctor_sync(self): @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestTable", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "TableAsync": "Table"}, + 
replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "TableAsync": "Table", "BigtableAsyncClient": "BigtableClient"}, ) class TestTableAsync: def _make_client(self, *args, **kwargs): @@ -1121,10 +1117,6 @@ def _make_client(self, *args, **kwargs): def _get_target_class(): return TableAsync - @property - def is_async(self): - return True - @pytest.mark.asyncio async def test_table_ctor(self): from google.cloud.bigtable.data._helpers import _WarmedInstanceKey @@ -1346,7 +1338,7 @@ async def test_customizable_retryable_errors( retry_fn = "retry_target" if is_stream: retry_fn += "_stream" - if self.is_async: + if CrossSync.is_async: retry_fn = f"CrossSync.{retry_fn}" else: retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index c2a9271dc..346e274a6 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -53,6 +53,7 @@ PooledChannel as PooledChannelAsync, ) else: + from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation from google.cloud.bigtable.data._sync.client import Table @@ -85,10 +86,6 @@ def _make_client(cls, *args, use_emulator=True, **kwargs): with mock.patch.dict(os.environ, env_mask): return cls._get_target_class()(*args, **kwargs) - @property - def is_async(self): - return True - def test_ctor(self): expected_project = "project-id" expected_pool_size = 11 @@ -117,7 +114,7 @@ def test_ctor_super_inits(self): credentials = AnonymousCredentials() client_options = {"api_endpoint": "foo.bar:1234"} options_parsed = client_options_lib.from_dict(client_options) - asyncio_portion = "-async" if self.is_async else "" + asyncio_portion = "-async" if CrossSync._Sync_Impl.is_async else "" transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: bigtable_client_init.return_value = None @@ -170,13 +167,13 @@ def test_ctor_dict_options(self): client.close() def test_veneer_grpc_headers(self): - client_component = "data-async" if self.is_async else "data" + client_component = "data-async" if CrossSync._Sync_Impl.is_async else "data" VENEER_HEADER_REGEX = re.compile( "gapic\\/[0-9]+\\.[\\w.-]+ gax\\/[0-9]+\\.[\\w.-]+ gccl\\/[0-9]+\\.[\\w.-]+-" + client_component + " gl-python\\/[0-9]+\\.[\\w.-]+ grpc\\/[0-9]+\\.[\\w.-]+" ) - if self.is_async: + if CrossSync_async: patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") else: patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") @@ -233,7 +230,7 @@ def test_channel_pool_rotation(self): def test_channel_pool_replace(self): import time - sleep_module = asyncio if self.is_async else time + sleep_module = asyncio if CrossSync._Sync_Impl.is_async else time with mock.patch.object(sleep_module, "sleep"): pool_size = 7 client = self._make_client(project="project-id", pool_size=pool_size) @@ -250,7 +247,7 @@ def test_channel_pool_replace(self): replace_idx, grace=grace_period, new_channel=new_channel ) close.assert_called_once() - if self.is_async: + if CrossSync._Sync_Impl.is_async: close.assert_called_once_with(grace=grace_period) close.assert_awaited_once() assert client.transport._grpc_channel._pool[replace_idx] == new_channel @@ -288,7 +285,7 @@ def test__start_background_channel_refresh(self, pool_size): client._start_background_channel_refresh() assert len(client._channel_refresh_tasks) == 
pool_size for task in client._channel_refresh_tasks: - if self.is_async: + if CrossSync._Sync_Impl.is_async: assert isinstance(task, asyncio.Task) else: assert isinstance(task, concurrent.futures.Future) @@ -344,7 +341,7 @@ def test__ping_and_warm_instances(self): gather.assert_called_once() partial_list = gather.call_args.args[0] assert len(partial_list) == 4 - if self.is_async: + if CrossSync._Sync_Impl.is_async: gather.assert_awaited_once() grpc_call_args = channel.unary_unary().call_args_list for idx, (_, kwargs) in enumerate(grpc_call_args): @@ -373,11 +370,13 @@ def test__ping_and_warm_single_instance(self): ) ) gather_tuple = ( - (asyncio, "gather") if self.is_async else (client_mock._executor, "submit") + (asyncio, "gather") + if CrossSync._Sync_Impl.is_async + else (client_mock._executor, "submit") ) with mock.patch.object(*gather_tuple, AsyncMock()) as gather: gather.side_effect = lambda *args, **kwargs: [mock.Mock() for _ in args] - if self.is_async: + if CrossSync._Sync_Impl.is_async: gather.side_effect = lambda *args, **kwargs: [None for _ in args] else: gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] @@ -414,7 +413,9 @@ def test__manage_channel_first_sleep( with mock.patch.object(time, "monotonic") as monotonic: monotonic.return_value = 0 sleep_tuple = ( - (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + (asyncio, "sleep") + if CrossSync._Sync_Impl.is_async + else (threading.Event, "wait") ) with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = asyncio.CancelledError @@ -443,7 +444,11 @@ def test__manage_channel_ping_and_warm(self): client_mock.transport.channels = channel_list new_channel = mock.Mock() client_mock.transport.grpc_channel._create_channel.return_value = new_channel - sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + sleep_tuple = ( + (asyncio, "sleep") + if CrossSync._Sync_Impl.is_async + else (threading.Event, "wait") + ) with mock.patch.object(*sleep_tuple): client_mock.transport.replace_channel.side_effect = asyncio.CancelledError ping_and_warm = client_mock._ping_and_warm_instances = AsyncMock() @@ -481,7 +486,9 @@ def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sle with mock.patch.object(time, "time") as time_mock: time_mock.return_value = 0 sleep_tuple = ( - (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + (asyncio, "sleep") + if CrossSync._Sync_Impl.is_async + else (threading.Event, "wait") ) with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = [None for i in range(num_cycles - 1)] + [ @@ -509,7 +516,11 @@ def test__manage_channel_random(self): import random import threading - sleep_tuple = (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + sleep_tuple = ( + (asyncio, "sleep") + if CrossSync._Sync_Impl.is_async + else (threading.Event, "wait") + ) with mock.patch.object(*sleep_tuple) as sleep: with mock.patch.object(random, "uniform") as uniform: uniform.return_value = 0 @@ -542,13 +553,15 @@ def test__manage_channel_refresh(self, num_cycles): expected_grace = 9 expected_refresh = 0.5 channel_idx = 1 - grpc_lib = grpc.aio if self.is_async else grpc + grpc_lib = grpc.aio if CrossSync._Sync_Impl.is_async else grpc new_channel = grpc_lib.insecure_channel("localhost:8080") with mock.patch.object( PooledBigtableGrpcAsyncIOTransport, "replace_channel" ) as replace_channel: sleep_tuple = ( - (asyncio, "sleep") if self.is_async else (threading.Event, "wait") + (asyncio, "sleep") + if 
CrossSync._Sync_Impl.is_async + else (threading.Event, "wait") ) with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = [None for i in range(num_cycles)] + [ @@ -2523,10 +2536,6 @@ def _make_client(self, *args, **kwargs): def _get_target_class(): return Table - @property - def is_async(self): - return True - def test_table_ctor(self): from google.cloud.bigtable.data._helpers import _WarmedInstanceKey @@ -2693,7 +2702,7 @@ def test_customizable_retryable_errors( retry_fn = "retry_target" if is_stream: retry_fn += "_stream" - if self.is_async: + if CrossSync._Sync_Impl.is_async: retry_fn = f"CrossSync.{retry_fn}" else: retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" @@ -2745,7 +2754,7 @@ def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): """check that all requests attach proper metadata headers""" profile = "profile" if include_app_profile else None with mock.patch.object( - BigtableAsyncClient, gapic_fn, mock.AsyncMock() + BigtableClient, gapic_fn, mock.AsyncMock() ) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") with self._make_client() as client: From 255e124e70d3789bc0c8b9c1bda6dfbcc681b44e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 24 Jun 2024 16:45:18 -0700 Subject: [PATCH 104/360] support dropping methods --- .../cloud/bigtable/data/_sync/cross_sync.py | 9 ++++- sync_surface_generator.py | 9 +++-- tests/unit/data/_async/test_client.py | 4 ++ tests/unit/data/_sync/test_client.py | 38 ------------------- 4 files changed, 17 insertions(+), 43 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 3b5f7ef1a..1f97f9c86 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -53,11 +53,18 @@ def decorator(func): return decorator + @staticmethod + def drop_method(*args, **kwargs): + def decorator(func): + return func + + return decorator + @classmethod def sync_output( cls, sync_path: str, - replace_symbols: dict["str", "str" | None] | None = None, + replace_symbols: dict["str", "str" | None ] | None = None, mypy_ignore: list[str] | None = None, ): replace_symbols = replace_symbols or {} diff --git a/sync_surface_generator.py b/sync_surface_generator.py index 224d81bf4..091a74492 100644 --- a/sync_surface_generator.py +++ b/sync_surface_generator.py @@ -117,12 +117,13 @@ def visit_AsyncFunctionDef(self, node): # TODO: make generic new_list = [] for decorator in node.decorator_list: - # check for cross_sync decorator - if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name) and decorator.func.value.id == "CrossSync": - decorator_type = decorator.func.attr - if decorator_type == "rename_sync": + # check for @CrossSync.x() decorators + if "CrossSync" in ast.dump(decorator): + if "rename_sync" in ast.dump(decorator): new_name = decorator.args[0].value node.name = new_name + elif "drop_method" in ast.dump(decorator): + return None else: new_list.append(decorator) node.decorator_list = new_list diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 0e9e1fa21..e51a82b7d 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -288,6 +288,7 @@ async def test_channel_pool_replace(self): assert client.transport._grpc_channel._pool[i] != start_pool[i] await client.close() + @CrossSync.drop_method @pytest.mark.filterwarnings("ignore::RuntimeWarning") 
def test__start_background_channel_refresh_sync(self): # should raise RuntimeError if called in a sync context @@ -330,6 +331,7 @@ async def test__start_background_channel_refresh(self, pool_size): ping_and_warm.assert_any_call(channel) await client.close() + @CrossSync.drop_method @pytest.mark.asyncio @pytest.mark.skipif( sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" @@ -1091,6 +1093,7 @@ async def test_context_manager(self): # actually close the client await true_close + @CrossSync.drop_method def test_client_ctor_sync(self): # initializing client in a sync context should raise RuntimeError @@ -1239,6 +1242,7 @@ async def test_table_ctor_invalid_timeout_values(self): assert "operation_timeout must be greater than 0" in str(e.value) await client.close() + @CrossSync.drop_method def test_table_ctor_sync(self): # initializing client in a sync context should raise RuntimeError from google.cloud.bigtable.data._async.client import TableAsync diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 346e274a6..4a1cda888 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -258,12 +258,6 @@ def test_channel_pool_replace(self): assert client.transport._grpc_channel._pool[i] != start_pool[i] client.close() - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test__start_background_channel_refresh_sync(self): - client = self._make_client(project="project-id", use_emulator=False) - with pytest.raises(RuntimeError): - client._start_background_channel_refresh() - def test__start_background_channel_refresh_tasks_exist(self): client = self._make_client(project="project-id", use_emulator=False) assert len(client._channel_refresh_tasks) > 0 @@ -295,20 +289,6 @@ def test__start_background_channel_refresh(self, pool_size): ping_and_warm.assert_any_call(channel) client.close() - @pytest.mark.skipif( - sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" - ) - def test__start_background_channel_refresh_tasks_names(self): - pool_size = 3 - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - for i in range(pool_size): - name = client._channel_refresh_tasks[i].get_name() - assert str(i) in name - assert "BigtableDataClientAsync channel refresh " in name - client.close() - def test__ping_and_warm_instances(self): """test ping and warm with mocked asyncio.gather""" client_mock = mock.Mock() @@ -982,18 +962,6 @@ def test_context_manager(self): close_mock.assert_awaited() true_close - def test_client_ctor_sync(self): - with pytest.warns(RuntimeWarning) as warnings: - client = self._make_client(project="project-id", use_emulator=False) - expected_warning = [w for w in warnings if "client.py" in w.filename] - assert len(expected_warning) == 1 - assert ( - "BigtableDataClientAsync should be started in an asyncio event loop." 
- in str(expected_warning[0].message) - ) - assert client.project == "project-id" - assert client._channel_refresh_tasks == [] - class TestBulkMutateRows: def _make_client(self, *args, **kwargs): @@ -2639,12 +2607,6 @@ def test_table_ctor_invalid_timeout_values(self): assert "operation_timeout must be greater than 0" in str(e.value) client.close() - def test_table_ctor_sync(self): - client = mock.Mock() - with pytest.raises(RuntimeError) as e: - Table(client, "instance-id", "table-id") - assert e.match("TableAsync must be created within an async event loop context.") - @pytest.mark.parametrize( "fn_name,fn_args,is_stream,extra_retryables", [ From 87aecb328053412989643b5f1c0170654016cf7f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 24 Jun 2024 17:08:00 -0700 Subject: [PATCH 105/360] got tests passing --- tests/unit/data/_async/test_client.py | 60 +++++++------- tests/unit/data/_sync/test_client.py | 108 +++++++++++--------------- 2 files changed, 71 insertions(+), 97 deletions(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index e51a82b7d..ce6acdcc4 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -20,7 +20,6 @@ import pytest -from google.api_core import grpc_helpers_async from google.cloud.bigtable.data import mutations from google.auth.credentials import AnonymousCredentials from google.cloud.bigtable_v2.types import ReadRowsResponse @@ -44,6 +43,7 @@ from mock import AsyncMock # type: ignore if CrossSync.is_async: + from google.api_core import grpc_helpers_async from google.cloud.bigtable_v2.services.bigtable.async_client import ( BigtableAsyncClient, ) @@ -54,8 +54,9 @@ PooledChannel as PooledChannelAsync, ) from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._async.client import TableAsync, BigtableDataClientAsync else: + from google.api_core import grpc_helpers from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( PooledBigtableGrpcTransport, @@ -64,22 +65,25 @@ PooledChannel, ) from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - from google.cloud.bigtable.data._sync.client import Table + from google.cloud.bigtable.data._sync.client import Table, BigtableDataClient @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestBigtableDataClient", + replace_symbols={ + "TestTableAsync": "TestTable", + "BigtableDataClientAsync": "BigtableDataClient", + "TableAsync": "Table", + "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", + "grpc_helpers_async": "grpc_helpers", + "PooledChannelAsync": "PooledChannel", + "BigtableAsyncClient": "BigtableClient", + "AsyncMock": "mock.Mock", + } ) class TestBigtableDataClientAsync: @staticmethod def _get_target_class(): - if CrossSync.is_async: - from google.cloud.bigtable.data._async.client import BigtableDataClientAsync - - return BigtableDataClientAsync - else: - from google.cloud.bigtable.data._sync.client import BigtableDataClient - - return BigtableDataClient + return BigtableDataClientAsync @classmethod def _make_client(cls, *args, use_emulator=True, **kwargs): @@ -196,7 +200,7 @@ async def test_veneer_grpc_headers(self): # client_info should be populated with headers to # detect as a veneer client - if CrossSync_async: + if CrossSync.is_async: patch = 
mock.patch("google.api_core.gapic_v1.method_async.wrap_method") else: patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") @@ -417,18 +421,8 @@ async def test__ping_and_warm_single_instance(self): client_mock, *args ) ) - gather_tuple = ( - (asyncio, "gather") if CrossSync.is_async else (client_mock._executor, "submit") - ) - with mock.patch.object(*gather_tuple, AsyncMock()) as gather: - gather.side_effect = lambda *args, **kwargs: [mock.Mock() for _ in args] - if CrossSync.is_async: - # simulate gather by returning the same number of items as passed in - # gather is expected to return None for each coroutine passed - gather.side_effect = lambda *args, **kwargs: [None for _ in args] - else: - # submit is expected to call the function passed, and return the result - gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] + with mock.patch.object(CrossSync, "gather_partials", AsyncMock()) as gather: + gather.side_effect = lambda *args, **kwargs: [fn() for fn in args[0]] channel = mock.Mock() # test with large set of instances client_mock._active_instances = [mock.Mock()] * 100 @@ -473,10 +467,7 @@ async def test__manage_channel_first_sleep( with mock.patch.object(time, "monotonic") as monotonic: monotonic.return_value = 0 - sleep_tuple = ( - (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: + with mock.patch.object(CrossSync, "event_wait") as sleep: sleep.side_effect = asyncio.CancelledError try: client = self._make_client(project="project-id") @@ -485,7 +476,7 @@ async def test__manage_channel_first_sleep( except asyncio.CancelledError: pass sleep.assert_called_once() - call_time = sleep.call_args[0][0] + call_time = sleep.call_args[0][1] assert ( abs(call_time - expected_sleep) < 0.1 ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" @@ -579,7 +570,7 @@ async def test__manage_channel_sleeps( except asyncio.CancelledError: pass assert sleep.call_count == num_cycles - total_sleep = sum([call[0][0] for call in sleep.call_args_list]) + total_sleep = sum([call[1]["timeout"] for call in sleep.call_args_list]) assert ( abs(total_sleep - expected_sleep) < 0.1 ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" @@ -1056,7 +1047,8 @@ async def test_close(self): ) as close_mock: await client.close() close_mock.assert_called_once() - close_mock.assert_awaited() + if CrossSync.is_async: + close_mock.assert_awaited() for task in tasks_list: assert task.done() assert client._channel_refresh_tasks == [] @@ -1070,7 +1062,8 @@ async def test_close_with_timeout(self): with mock.patch.object(CrossSync, "wait", AsyncMock()) as wait_for_mock: await client.close(timeout=expected_timeout) wait_for_mock.assert_called_once() - wait_for_mock.assert_awaited() + if CrossSync.is_async: + wait_for_mock.assert_awaited() assert wait_for_mock.call_args[1]["timeout"] == expected_timeout client._channel_refresh_tasks = tasks await client.close() @@ -1089,7 +1082,8 @@ async def test_context_manager(self): assert client._active_instances == set() close_mock.assert_not_called() close_mock.assert_called_once() - close_mock.assert_awaited() + if CrossSync.is_async: + close_mock.assert_awaited() # actually close the client await true_close diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 4a1cda888..1e241683d 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py 
@@ -20,7 +20,6 @@ import grpc import pytest import re -import sys try: from unittest import mock @@ -30,7 +29,6 @@ from mock import AsyncMock from google.api_core import exceptions as core_exceptions -from google.api_core import grpc_helpers_async from google.auth.credentials import AnonymousCredentials from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data import mutations @@ -43,32 +41,24 @@ from google.cloud.bigtable_v2.types import ReadRowsResponse if CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel as PooledChannelAsync, - ) + pass else: + from google.api_core import grpc_helpers from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledBigtableGrpcTransport, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledChannel, + ) from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - from google.cloud.bigtable.data._sync.client import Table + from google.cloud.bigtable.data._sync.client import Table, BigtableDataClient class TestBigtableDataClient: @staticmethod def _get_target_class(): - if CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable.data._async.client import BigtableDataClientAsync - - return BigtableDataClientAsync - else: - from google.cloud.bigtable.data._sync.client import BigtableDataClient - - return BigtableDataClient + return BigtableDataClient @classmethod def _make_client(cls, *args, use_emulator=True, **kwargs): @@ -116,7 +106,7 @@ def test_ctor_super_inits(self): options_parsed = client_options_lib.from_dict(client_options) asyncio_portion = "-async" if CrossSync._Sync_Impl.is_async else "" transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" - with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: + with mock.patch.object(BigtableClient, "__init__") as bigtable_client_init: bigtable_client_init.return_value = None with mock.patch.object( ClientWithProject, "__init__" @@ -147,7 +137,7 @@ def test_ctor_dict_options(self): from google.api_core.client_options import ClientOptions client_options = {"api_endpoint": "foo.bar:1234"} - with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: + with mock.patch.object(BigtableClient, "__init__") as bigtable_client_init: try: self._make_client(client_options=client_options) except TypeError: @@ -173,7 +163,7 @@ def test_veneer_grpc_headers(self): + client_component + " gl-python\\/[0-9]+\\.[\\w.-]+ grpc\\/[0-9]+\\.[\\w.-]+" ) - if CrossSync_async: + if CrossSync._Sync_Impl.is_async: patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") else: patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") @@ -195,7 +185,7 @@ def test_veneer_grpc_headers(self): def test_channel_pool_creation(self): pool_size = 14 with mock.patch.object( - grpc_helpers_async, "create_channel", AsyncMock() + grpc_helpers, "create_channel", mock.Mock() ) as create_channel: client = self._make_client(project="project-id", pool_size=pool_size) assert create_channel.call_count == pool_size @@ -208,7 +198,7 @@ def 
test_channel_pool_creation(self): def test_channel_pool_rotation(self): pool_size = 7 - with mock.patch.object(PooledChannelAsync, "next_channel") as next_channel: + with mock.patch.object(PooledChannel, "next_channel") as next_channel: client = self._make_client(project="project-id", pool_size=pool_size) assert len(client.transport._grpc_channel._pool) == pool_size next_channel.reset_mock() @@ -271,7 +261,7 @@ def test__start_background_channel_refresh(self, pool_size): import concurrent.futures with mock.patch.object( - self._get_target_class(), "_ping_and_warm_instances", AsyncMock() + self._get_target_class(), "_ping_and_warm_instances", mock.Mock() ) as ping_and_warm: client = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False @@ -298,7 +288,7 @@ def test__ping_and_warm_instances(self): ) ) with mock.patch.object( - CrossSync._Sync_Impl, "gather_partials", AsyncMock() + CrossSync._Sync_Impl, "gather_partials", mock.Mock() ) as gather: gather.side_effect = lambda partials, **kwargs: [None for _ in partials] channel = mock.Mock() @@ -349,17 +339,10 @@ def test__ping_and_warm_single_instance(self): client_mock, *args ) ) - gather_tuple = ( - (asyncio, "gather") - if CrossSync._Sync_Impl.is_async - else (client_mock._executor, "submit") - ) - with mock.patch.object(*gather_tuple, AsyncMock()) as gather: - gather.side_effect = lambda *args, **kwargs: [mock.Mock() for _ in args] - if CrossSync._Sync_Impl.is_async: - gather.side_effect = lambda *args, **kwargs: [None for _ in args] - else: - gather.side_effect = lambda fn, **kwargs: [fn(**kwargs)] + with mock.patch.object( + CrossSync._Sync_Impl, "gather_partials", mock.Mock() + ) as gather: + gather.side_effect = lambda *args, **kwargs: [fn() for fn in args[0]] channel = mock.Mock() client_mock._active_instances = [mock.Mock()] * 100 test_key = ("test-instance", "test-table", "test-app-profile") @@ -387,17 +370,11 @@ def test__ping_and_warm_single_instance(self): def test__manage_channel_first_sleep( self, refresh_interval, wait_time, expected_sleep ): - import threading import time with mock.patch.object(time, "monotonic") as monotonic: monotonic.return_value = 0 - sleep_tuple = ( - (asyncio, "sleep") - if CrossSync._Sync_Impl.is_async - else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: + with mock.patch.object(CrossSync._Sync_Impl, "event_wait") as sleep: sleep.side_effect = asyncio.CancelledError try: client = self._make_client(project="project-id") @@ -406,7 +383,7 @@ def test__manage_channel_first_sleep( except asyncio.CancelledError: pass sleep.assert_called_once() - call_time = sleep.call_args[0][0] + call_time = sleep.call_args[0][1] assert ( abs(call_time - expected_sleep) < 0.1 ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" @@ -431,7 +408,7 @@ def test__manage_channel_ping_and_warm(self): ) with mock.patch.object(*sleep_tuple): client_mock.transport.replace_channel.side_effect = asyncio.CancelledError - ping_and_warm = client_mock._ping_and_warm_instances = AsyncMock() + ping_and_warm = client_mock._ping_and_warm_instances = mock.Mock() try: channel_idx = 1 self._get_target_class()._manage_channel(client_mock, channel_idx, 10) @@ -486,7 +463,9 @@ def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sle except asyncio.CancelledError: pass assert sleep.call_count == num_cycles - total_sleep = sum([call[0][0] for call in sleep.call_args_list]) + total_sleep = sum( + [call[1]["timeout"] for call 
in sleep.call_args_list] + ) assert ( abs(total_sleep - expected_sleep) < 0.1 ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" @@ -536,7 +515,7 @@ def test__manage_channel_refresh(self, num_cycles): grpc_lib = grpc.aio if CrossSync._Sync_Impl.is_async else grpc new_channel = grpc_lib.insecure_channel("localhost:8080") with mock.patch.object( - PooledBigtableGrpcAsyncIOTransport, "replace_channel" + PooledBigtableGrpcTransport, "replace_channel" ) as replace_channel: sleep_tuple = ( (asyncio, "sleep") @@ -548,7 +527,7 @@ def test__manage_channel_refresh(self, num_cycles): asyncio.CancelledError ] with mock.patch.object( - grpc_helpers_async, "create_channel" + grpc_helpers, "create_channel" ) as create_channel: create_channel.return_value = new_channel with mock.patch.object( @@ -591,7 +570,7 @@ def test__register_instance(self): ) mock_channels = [mock.Mock() for i in range(5)] client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = AsyncMock() + client_mock._ping_and_warm_instances = mock.Mock() table_mock = mock.Mock() self._get_target_class()._register_instance( client_mock, "instance-1", table_mock @@ -667,7 +646,7 @@ def test__register_instance_state( ) mock_channels = [mock.Mock() for i in range(5)] client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = AsyncMock() + client_mock._ping_and_warm_instances = mock.Mock() table_mock = mock.Mock() for instance, table, profile in insert_instances: table_mock.table_name = table @@ -814,7 +793,7 @@ def test_get_table(self): expected_instance_id, expected_table_id, expected_app_profile_id ) asyncio.sleep(0) - assert isinstance(table, TestTableAsync._get_target_class()) + assert isinstance(table, TestTable._get_target_class()) assert table.table_id == expected_table_id assert ( table.table_name @@ -838,7 +817,7 @@ def test_get_table_arg_passthrough(self): """All arguments passed in get_table should be sent to constructor""" with self._make_client(project="project-id") as client: with mock.patch.object( - TestTableAsync._get_target_class(), "__init__" + TestTable._get_target_class(), "__init__" ) as mock_constructor: mock_constructor.return_value = None assert not client._active_instances @@ -870,15 +849,13 @@ def test_get_table_context_manager(self): expected_instance_id = "instance-id" expected_app_profile_id = "app-profile-id" expected_project_id = "project-id" - with mock.patch.object( - TestTableAsync._get_target_class(), "close" - ) as close_mock: + with mock.patch.object(TestTable._get_target_class(), "close") as close_mock: with self._make_client(project=expected_project_id) as client: with client.get_table( expected_instance_id, expected_table_id, expected_app_profile_id ) as table: asyncio.sleep(0) - assert isinstance(table, TestTableAsync._get_target_class()) + assert isinstance(table, TestTable._get_target_class()) assert table.table_id == expected_table_id assert ( table.table_name @@ -923,11 +900,12 @@ def test_close(self): for task in client._channel_refresh_tasks: assert not task.done() with mock.patch.object( - PooledBigtableGrpcAsyncIOTransport, "close", AsyncMock() + PooledBigtableGrpcTransport, "close", mock.Mock() ) as close_mock: client.close() close_mock.assert_called_once() - close_mock.assert_awaited() + if CrossSync._Sync_Impl.is_async: + close_mock.assert_awaited() for task in tasks_list: assert task.done() assert client._channel_refresh_tasks == [] @@ -938,17 +916,18 @@ def test_close_with_timeout(self): 
client = self._make_client(project="project-id", pool_size=pool_size) tasks = list(client._channel_refresh_tasks) with mock.patch.object( - CrossSync._Sync_Impl, "wait", AsyncMock() + CrossSync._Sync_Impl, "wait", mock.Mock() ) as wait_for_mock: client.close(timeout=expected_timeout) wait_for_mock.assert_called_once() - wait_for_mock.assert_awaited() + if CrossSync._Sync_Impl.is_async: + wait_for_mock.assert_awaited() assert wait_for_mock.call_args[1]["timeout"] == expected_timeout client._channel_refresh_tasks = tasks client.close() def test_context_manager(self): - close_mock = AsyncMock() + close_mock = mock.Mock() true_close = None with self._make_client(project="project-id") as client: true_close = client.close() @@ -959,7 +938,8 @@ def test_context_manager(self): assert client._active_instances == set() close_mock.assert_not_called() close_mock.assert_called_once() - close_mock.assert_awaited() + if CrossSync._Sync_Impl.is_async: + close_mock.assert_awaited() true_close From a1426a50f3453da7ec7e4906ccfb32243f6eedc0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 24 Jun 2024 17:20:55 -0700 Subject: [PATCH 106/360] got test_client sync tests passing --- .../cloud/bigtable/data/_sync/cross_sync.py | 7 +--- tests/unit/data/_async/test_client.py | 42 ++++++++++++++----- tests/unit/data/_sync/test_client.py | 42 ++++++++++++------- 3 files changed, 61 insertions(+), 30 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 1f97f9c86..9a5fc5a08 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -54,11 +54,8 @@ def decorator(func): return decorator @staticmethod - def drop_method(*args, **kwargs): - def decorator(func): - return func - - return decorator + def drop_method(func): + return func @classmethod def sync_output( diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index ce6acdcc4..b7a5cdba0 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -115,7 +115,7 @@ async def test_ctor(self): credentials=expected_credentials, use_emulator=False, ) - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert client.project == expected_project assert len(client.transport._grpc_channel._pool) == expected_pool_size assert not client._active_instances @@ -329,7 +329,8 @@ async def test__start_background_channel_refresh(self, pool_size): assert isinstance(task, asyncio.Task) else: assert isinstance(task, concurrent.futures.Future) - await asyncio.sleep(0.1) + if CrossSync.is_async: + await asyncio.sleep(0.1) assert ping_and_warm.call_count == pool_size for channel in client.transport._grpc_channel._pool: ping_and_warm.assert_any_call(channel) @@ -570,7 +571,10 @@ async def test__manage_channel_sleeps( except asyncio.CancelledError: pass assert sleep.call_count == num_cycles - total_sleep = sum([call[1]["timeout"] for call in sleep.call_args_list]) + if CrossSync.is_async: + total_sleep = sum([call[0][0] for call in sleep.call_args_list]) + else: + total_sleep = sum([call[1]["timeout"] for call in sleep.call_args_list]) assert ( abs(total_sleep - expected_sleep) < 0.1 ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" @@ -842,12 +846,20 @@ async def test__multiple_table_registration(self): assert id(table_1) in client._instance_owners[instance_1_key] # duplicate table should register in instance_owners under 
same key async with client.get_table("instance_1", "table_1") as table_2: + assert table_2._register_instance_future is not None + if not CrossSync.is_async: + # give the background task time to run + table_2._register_instance_future.result() assert len(client._instance_owners[instance_1_key]) == 2 assert len(client._active_instances) == 1 assert id(table_1) in client._instance_owners[instance_1_key] assert id(table_2) in client._instance_owners[instance_1_key] # unique table should register in instance_owners and active_instances async with client.get_table("instance_1", "table_3") as table_3: + assert table_3._register_instance_future is not None + if not CrossSync.is_async: + # give the background task time to run + table_3._register_instance_future.result() instance_3_path = client._gapic_client.instance_path( client.project, "instance_1" ) @@ -879,7 +891,15 @@ async def test__multiple_instance_registration(self): async with self._make_client(project="project-id") as client: async with client.get_table("instance_1", "table_1") as table_1: + assert table_1._register_instance_future is not None + if not CrossSync.is_async: + # give the background task time to run + table_1._register_instance_future.result() async with client.get_table("instance_2", "table_2") as table_2: + assert table_2._register_instance_future is not None + if not CrossSync.is_async: + # give the background task time to run + table_2._register_instance_future.result() instance_1_path = client._gapic_client.instance_path( client.project, "instance_1" ) @@ -922,7 +942,7 @@ async def test_get_table(self): expected_table_id, expected_app_profile_id, ) - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert isinstance(table, TestTableAsync._get_target_class()) assert table.table_id == expected_table_id assert ( @@ -994,7 +1014,7 @@ async def test_get_table_context_manager(self): expected_table_id, expected_app_profile_id, ) as table: - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert isinstance(table, TestTableAsync._get_target_class()) assert table.table_id == expected_table_id assert ( @@ -1104,7 +1124,7 @@ def test_client_ctor_sync(self): @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestTable", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "TableAsync": "Table", "BigtableAsyncClient": "BigtableClient"}, + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "TableAsync": "Table", "BigtableAsyncClient": "BigtableClient", "AsyncMock": "mock.Mock"}, ) class TestTableAsync: def _make_client(self, *args, **kwargs): @@ -1142,7 +1162,7 @@ async def test_table_ctor(self): default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, ) - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert table.table_id == expected_table_id assert table.instance_id == expected_instance_id assert table.app_profile_id == expected_app_profile_id @@ -1194,7 +1214,7 @@ async def test_table_ctor_defaults(self): expected_instance_id, expected_table_id, ) - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert table.table_id == expected_table_id assert table.instance_id == expected_instance_id assert table.app_profile_id is None @@ -1394,7 +1414,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ profile = "profile" if include_app_profile else None with mock.patch.object( - BigtableAsyncClient, 
gapic_fn, mock.AsyncMock() + BigtableAsyncClient, gapic_fn, AsyncMock() ) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") async with self._make_client() as client: @@ -1438,10 +1458,10 @@ def _make_client(self, *args, **kwargs): def _make_table(self, *args, **kwargs): client_mock = mock.Mock() client_mock._register_instance.side_effect = ( - lambda *args, **kwargs: asyncio.sleep(0) + lambda *args, **kwargs: CrossSync.yield_to_event_loop() ) client_mock._remove_instance_registration.side_effect = ( - lambda *args, **kwargs: asyncio.sleep(0) + lambda *args, **kwargs: CrossSync.yield_to_event_loop() ) kwargs["instance_id"] = kwargs.get( "instance_id", args[0] if args else "instance" diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 1e241683d..4140cbe9c 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -86,7 +86,7 @@ def test_ctor(self): credentials=expected_credentials, use_emulator=False, ) - asyncio.sleep(0) + CrossSync._Sync_Impl.yield_to_event_loop() assert client.project == expected_project assert len(client.transport._grpc_channel._pool) == expected_pool_size assert not client._active_instances @@ -273,7 +273,8 @@ def test__start_background_channel_refresh(self, pool_size): assert isinstance(task, asyncio.Task) else: assert isinstance(task, concurrent.futures.Future) - asyncio.sleep(0.1) + if CrossSync._Sync_Impl.is_async: + asyncio.sleep(0.1) assert ping_and_warm.call_count == pool_size for channel in client.transport._grpc_channel._pool: ping_and_warm.assert_any_call(channel) @@ -463,9 +464,12 @@ def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sle except asyncio.CancelledError: pass assert sleep.call_count == num_cycles - total_sleep = sum( - [call[1]["timeout"] for call in sleep.call_args_list] - ) + if CrossSync._Sync_Impl.is_async: + total_sleep = sum([call[0][0] for call in sleep.call_args_list]) + else: + total_sleep = sum( + [call[1]["timeout"] for call in sleep.call_args_list] + ) assert ( abs(total_sleep - expected_sleep) < 0.1 ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" @@ -721,11 +725,17 @@ def test__multiple_table_registration(self): assert len(client._active_instances) == 1 assert id(table_1) in client._instance_owners[instance_1_key] with client.get_table("instance_1", "table_1") as table_2: + assert table_2._register_instance_future is not None + if not CrossSync._Sync_Impl.is_async: + table_2._register_instance_future.result() assert len(client._instance_owners[instance_1_key]) == 2 assert len(client._active_instances) == 1 assert id(table_1) in client._instance_owners[instance_1_key] assert id(table_2) in client._instance_owners[instance_1_key] with client.get_table("instance_1", "table_3") as table_3: + assert table_3._register_instance_future is not None + if not CrossSync._Sync_Impl.is_async: + table_3._register_instance_future.result() instance_3_path = client._gapic_client.instance_path( client.project, "instance_1" ) @@ -754,7 +764,13 @@ def test__multiple_instance_registration(self): with self._make_client(project="project-id") as client: with client.get_table("instance_1", "table_1") as table_1: + assert table_1._register_instance_future is not None + if not CrossSync._Sync_Impl.is_async: + table_1._register_instance_future.result() with client.get_table("instance_2", "table_2") as table_2: + assert table_2._register_instance_future is not None + if not 
CrossSync._Sync_Impl.is_async: + table_2._register_instance_future.result() instance_1_path = client._gapic_client.instance_path( client.project, "instance_1" ) @@ -792,7 +808,7 @@ def test_get_table(self): table = client.get_table( expected_instance_id, expected_table_id, expected_app_profile_id ) - asyncio.sleep(0) + CrossSync._Sync_Impl.yield_to_event_loop() assert isinstance(table, TestTable._get_target_class()) assert table.table_id == expected_table_id assert ( @@ -854,7 +870,7 @@ def test_get_table_context_manager(self): with client.get_table( expected_instance_id, expected_table_id, expected_app_profile_id ) as table: - asyncio.sleep(0) + CrossSync._Sync_Impl.yield_to_event_loop() assert isinstance(table, TestTable._get_target_class()) assert table.table_id == expected_table_id assert ( @@ -1708,10 +1724,10 @@ def _make_client(self, *args, **kwargs): def _make_table(self, *args, **kwargs): client_mock = mock.Mock() client_mock._register_instance.side_effect = ( - lambda *args, **kwargs: asyncio.sleep(0) + lambda *args, **kwargs: CrossSync._Sync_Impl.yield_to_event_loop() ) client_mock._remove_instance_registration.side_effect = ( - lambda *args, **kwargs: asyncio.sleep(0) + lambda *args, **kwargs: CrossSync._Sync_Impl.yield_to_event_loop() ) kwargs["instance_id"] = kwargs.get( "instance_id", args[0] if args else "instance" @@ -2510,7 +2526,7 @@ def test_table_ctor(self): default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, ) - asyncio.sleep(0) + CrossSync._Sync_Impl.yield_to_event_loop() assert table.table_id == expected_table_id assert table.instance_id == expected_instance_id assert table.app_profile_id == expected_app_profile_id @@ -2551,7 +2567,7 @@ def test_table_ctor_defaults(self): client = self._make_client() assert not client._active_instances table = Table(client, expected_instance_id, expected_table_id) - asyncio.sleep(0) + CrossSync._Sync_Impl.yield_to_event_loop() assert table.table_id == expected_table_id assert table.instance_id == expected_instance_id assert table.app_profile_id is None @@ -2695,9 +2711,7 @@ def test_customizable_retryable_errors( def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): """check that all requests attach proper metadata headers""" profile = "profile" if include_app_profile else None - with mock.patch.object( - BigtableClient, gapic_fn, mock.AsyncMock() - ) as gapic_mock: + with mock.patch.object(BigtableClient, gapic_fn, mock.Mock()) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") with self._make_client() as client: table = Table(client, "instance-id", "table-id", profile) From ba351e3b08ea5ae671b626cd8b1bf238e56294b0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Jun 2024 15:34:04 -0600 Subject: [PATCH 107/360] dded simplified transformer to crosssync --- .../cloud/bigtable/data/_sync/cross_sync.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 9a5fc5a08..5e04f160c 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -275,3 +275,33 @@ def create_task( @staticmethod def yield_to_event_loop() -> None: pass + +import ast + +class CrossSyncTransformer(ast.NodeTransformer): + pass + +if __name__ == "__main__": + import os + import glob + import importlib + import inspect + import itertools + import 
black + import autoflake + # find all cross_sync decorated classes + search_root = sys.argv[1] + found_files = [path.replace("/", ".")[:-3] for path in glob.glob(search_root + "/**/*.py", recursive=True)] + found_classes = itertools.chain.from_iterable([inspect.getmembers(importlib.import_module(path), inspect.isclass) for path in found_files]) + cross_sync_classes = [(name, cls) for name, cls in found_classes if hasattr(cls, "cross_sync_enabled")] + # convert files + file_buffers = {} + for cls_name, cls in cross_sync_classes: + ast_tree = ast.parse(inspect.getsource(cls)) + transformed_tree = CrossSyncTransformer().visit(ast_tree) + file_buffers[cls.cross_sync_file_path] = file_buffers.get(cls.cross_sync_file_path, "") + ast.unparse(transformed_tree) + # write to disk + for file_path, buffer in file_buffers.items(): + # cleaned = black.format_str(autoflake.fix_code(buffer, remove_all_unused_imports=True), mode=black.FileMode()) + with open(file_path, "w") as f: + f.write(buffer) From d6fac8efce6fb6f50dc19ba1d464eebbd2abfee7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Jun 2024 16:02:59 -0600 Subject: [PATCH 108/360] strip out basic sync code --- .../cloud/bigtable/data/_sync/cross_sync.py | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 5e04f160c..37f16b859 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -279,7 +279,43 @@ def yield_to_event_loop() -> None: import ast class CrossSyncTransformer(ast.NodeTransformer): - pass + + def visit_Await(self, node): + return self.visit(node.value) + + def visit_AsyncFor(self, node): + return ast.copy_location( + ast.For( + self.visit(node.target), + self.visit(node.iter), + [self.visit(stmt) for stmt in node.body], + [self.visit(stmt) for stmt in node.orelse], + ), + node, + ) + + def visit_AsyncWith(self, node): + return ast.copy_location( + ast.With( + [self.visit(item) for item in node.items], + [self.visit(stmt) for stmt in node.body], + ), + node, + ) + + def visit_AsyncFunctionDef(self, node): + return ast.copy_location( + ast.FunctionDef( + node.name, + self.visit(node.args), + [self.visit(stmt) for stmt in node.body], + [self.visit(decorator) for decorator in node.decorator_list], + node.returns and self.visit(node.returns), + ), + node, + ) + + if __name__ == "__main__": import os From d25a517ac6a98d4bb8751e43564d2a2d70642be5 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Jun 2024 16:44:51 -0600 Subject: [PATCH 109/360] include headers --- .../cloud/bigtable/data/_sync/cross_sync.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 37f16b859..2c5b12fe8 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -315,7 +315,23 @@ def visit_AsyncFunctionDef(self, node): node, ) +header = """ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is automatically generated by CrossSync. Do not edit manually. +""" if __name__ == "__main__": import os @@ -335,9 +351,9 @@ def visit_AsyncFunctionDef(self, node): for cls_name, cls in cross_sync_classes: ast_tree = ast.parse(inspect.getsource(cls)) transformed_tree = CrossSyncTransformer().visit(ast_tree) - file_buffers[cls.cross_sync_file_path] = file_buffers.get(cls.cross_sync_file_path, "") + ast.unparse(transformed_tree) + file_buffers[cls.cross_sync_file_path] = file_buffers.get(cls.cross_sync_file_path, header) + ast.unparse(transformed_tree) + "\n" # write to disk for file_path, buffer in file_buffers.items(): - # cleaned = black.format_str(autoflake.fix_code(buffer, remove_all_unused_imports=True), mode=black.FileMode()) + cleaned = black.format_str(autoflake.fix_code(buffer, remove_all_unused_imports=True), mode=black.FileMode()) with open(file_path, "w") as f: f.write(buffer) From 5dd32fdef17213593c5de3dfcd9110375bc1b0d5 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 2 Jul 2024 14:27:25 -0600 Subject: [PATCH 110/360] add file imports --- google/cloud/bigtable/data/_sync/cross_sync.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 2c5b12fe8..f4bbedb04 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -63,9 +63,11 @@ def sync_output( sync_path: str, replace_symbols: dict["str", "str" | None ] | None = None, mypy_ignore: list[str] | None = None, + include_file_imports: bool = True, ): replace_symbols = replace_symbols or {} mypy_ignore = mypy_ignore or [] + include_file_imports = include_file_imports # return the async class unchanged def decorator(async_cls): @@ -76,6 +78,7 @@ def decorator(async_cls): async_cls.cross_sync_file_path = "/".join(sync_path.split(".")[:-1]) + ".py" async_cls.cross_sync_replace_symbols = replace_symbols async_cls.cross_sync_mypy_ignore = mypy_ignore + async_cls.include_file_imports = include_file_imports return async_cls return decorator @@ -344,11 +347,17 @@ def visit_AsyncFunctionDef(self, node): # find all cross_sync decorated classes search_root = sys.argv[1] found_files = [path.replace("/", ".")[:-3] for path in glob.glob(search_root + "/**/*.py", recursive=True)] - found_classes = itertools.chain.from_iterable([inspect.getmembers(importlib.import_module(path), inspect.isclass) for path in found_files]) - cross_sync_classes = [(name, cls) for name, cls in found_classes if hasattr(cls, "cross_sync_enabled")] + found_classes = list(itertools.chain.from_iterable([inspect.getmembers(importlib.import_module(path), inspect.isclass) for path in found_files])) + cross_sync_classes = {name for name, cls in found_classes if hasattr(cls, "cross_sync_enabled")} # convert files file_buffers = {} - for cls_name, cls in cross_sync_classes: + for cls_name, cls in [entry for entry in found_classes if entry[0] in cross_sync_classes]: + if cls.include_file_imports: + # add imports if requested + with open(inspect.getfile(cls)) as f: + 
full_ast = ast.parse(f.read()) + imports = [node for node in full_ast.body if isinstance(node, (ast.Import, ast.ImportFrom, ast.If))] + file_buffers[cls.cross_sync_file_path] = file_buffers.get(cls.cross_sync_file_path, header) + "\n".join([ast.unparse(node) for node in imports]) + "\n" ast_tree = ast.parse(inspect.getsource(cls)) transformed_tree = CrossSyncTransformer().visit(ast_tree) file_buffers[cls.cross_sync_file_path] = file_buffers.get(cls.cross_sync_file_path, header) + ast.unparse(transformed_tree) + "\n" From cd40ba9522ab02b8041c548c07ab59e413b81ae9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 10:32:53 -0600 Subject: [PATCH 111/360] refactoring --- .../cloud/bigtable/data/_sync/cross_sync.py | 151 ++++++++++-------- .../cloud/bigtable/data/_sync/transformers.py | 96 +++++++++++ 2 files changed, 177 insertions(+), 70 deletions(-) create mode 100644 google/cloud/bigtable/data/_sync/transformers.py diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index f4bbedb04..6be4ad2a6 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -279,62 +279,80 @@ def create_task( def yield_to_event_loop() -> None: pass +from dataclasses import dataclass, field +from typing import ClassVar import ast -class CrossSyncTransformer(ast.NodeTransformer): - - def visit_Await(self, node): - return self.visit(node.value) - - def visit_AsyncFor(self, node): - return ast.copy_location( - ast.For( - self.visit(node.target), - self.visit(node.iter), - [self.visit(stmt) for stmt in node.body], - [self.visit(stmt) for stmt in node.orelse], - ), - node, - ) - - def visit_AsyncWith(self, node): - return ast.copy_location( - ast.With( - [self.visit(item) for item in node.items], - [self.visit(stmt) for stmt in node.body], - ), - node, +from google.cloud.bigtable.data._sync import transformers + +@dataclass +class CrossSyncArtifact: + file_path: str + imports: list[ast.Import | ast.ImportFrom] = field(default_factory=list) + converted_classes: dict[type, ast.ClassDef] = field(default_factory=dict) + _instances: ClassVar[dict[str, CrossSyncArtifact]] = {} + + def __hash__(self): + return hash(self.file_path) + + def render(self, with_black=True) -> str: + full_str = ( + "# Copyright 2024 Google LLC\n" + "#\n" + '# Licensed under the Apache License, Version 2.0 (the "License");\n' + '# you may not use this file except in compliance with the License.\n' + '# You may obtain a copy of the License at\n' + '#\n' + '# http://www.apache.org/licenses/LICENSE-2.0\n' + '#\n' + '# Unless required by applicable law or agreed to in writing, software\n' + '# distributed under the License is distributed on an "AS IS" BASIS,\n' + '# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n' + '# See the License for the specific language governing permissions and\n' + '# limitations under the License.\n' + '#\n' + '# This file is automatically generated by CrossSync. 
Do not edit manually.\n' ) + full_str += "\n".join([ast.unparse(node) for node in self.imports]) + full_str += "\n\n" + full_str += "\n".join([ast.unparse(node) for node in self.converted_classes.values()]) + if with_black: + cleaned = black.format_str(autoflake.fix_code(full_str, remove_all_unused_imports=True), mode=black.FileMode()) + return cleaned + else: + return full_str - def visit_AsyncFunctionDef(self, node): - return ast.copy_location( - ast.FunctionDef( - node.name, - self.visit(node.args), - [self.visit(stmt) for stmt in node.body], - [self.visit(decorator) for decorator in node.decorator_list], - node.returns and self.visit(node.returns), - ), - node, - ) - -header = """ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This file is automatically generated by CrossSync. Do not edit manually. -""" + @classmethod + def get_for_path(cls, path: str) -> CrossSyncArtifact: + if path not in cls._instances: + cls._instances[path] = CrossSyncArtifact(path) + return cls._instances[path] + + def add_class(self, cls): + if cls in self.converted_classes: + return + crosssync_converter = transformers.SymbolReplacer({"CrossSync": "CrossSync._Sync_Impl"}) + # convert class + cls_node = ast.parse(inspect.getsource(cls)).body[0] + # update name + cls_node.name = cls.cross_sync_class_name + # remove cross_sync decorator + if hasattr(cls_node, "decorator_list"): + cls_node.decorator_list = [d for d in cls_node.decorator_list if not isinstance(d, ast.Call) or not isinstance(d.func, ast.Attribute) or not isinstance(d.func.value, ast.Name) or d.func.value.id != "CrossSync"] + # do ast transformations + converted = transformers.AsyncToSync().visit(cls_node) + if cls.cross_sync_replace_symbols: + converted = transformers.SymbolReplacer(cls.cross_sync_replace_symbols).visit(converted) + converted = crosssync_converter.visit(converted) + converted = transformers.HandleCrossSyncDecorators().visit(converted) + self.converted_classes[cls] = converted + # add imports for added class if required + if cls.include_file_imports and not self.imports: + with open(inspect.getfile(cls)) as f: + full_ast = ast.parse(f.read()) + for node in full_ast.body: + if isinstance(node, (ast.Import, ast.ImportFrom, ast.If)): + self.imports.append(crosssync_converter.visit(node)) if __name__ == "__main__": import os @@ -348,21 +366,14 @@ def visit_AsyncFunctionDef(self, node): search_root = sys.argv[1] found_files = [path.replace("/", ".")[:-3] for path in glob.glob(search_root + "/**/*.py", recursive=True)] found_classes = list(itertools.chain.from_iterable([inspect.getmembers(importlib.import_module(path), inspect.isclass) for path in found_files])) - cross_sync_classes = {name for name, cls in found_classes if hasattr(cls, "cross_sync_enabled")} - # convert files - file_buffers = {} - for cls_name, cls in [entry for entry in found_classes if entry[0] in cross_sync_classes]: - if cls.include_file_imports: - # add imports if requested - with open(inspect.getfile(cls)) as f: - full_ast = 
ast.parse(f.read()) - imports = [node for node in full_ast.body if isinstance(node, (ast.Import, ast.ImportFrom, ast.If))] - file_buffers[cls.cross_sync_file_path] = file_buffers.get(cls.cross_sync_file_path, header) + "\n".join([ast.unparse(node) for node in imports]) + "\n" - ast_tree = ast.parse(inspect.getsource(cls)) - transformed_tree = CrossSyncTransformer().visit(ast_tree) - file_buffers[cls.cross_sync_file_path] = file_buffers.get(cls.cross_sync_file_path, header) + ast.unparse(transformed_tree) + "\n" - # write to disk - for file_path, buffer in file_buffers.items(): - cleaned = black.format_str(autoflake.fix_code(buffer, remove_all_unused_imports=True), mode=black.FileMode()) - with open(file_path, "w") as f: - f.write(buffer) + file_obj_set = set() + for name, cls in found_classes: + if hasattr(cls, "cross_sync_enabled"): + file_obj = CrossSyncArtifact.get_for_path(cls.cross_sync_file_path) + file_obj.add_class(cls) + file_obj_set.add(file_obj) + for file_obj in file_obj_set: + with open(file_obj.file_path, "w") as f: + f.write(file_obj.render()) + + diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py new file mode 100644 index 000000000..7197c06f1 --- /dev/null +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -0,0 +1,96 @@ +import ast + +class SymbolReplacer(ast.NodeTransformer): + + def __init__(self, replacements): + self.replacements = replacements + + def visit_Name(self, node): + if node.id in self.replacements: + node.id = self.replacements[node.id] + return node + + def visit_Attribute(self, node): + return ast.copy_location( + ast.Attribute( + self.visit(node.value), + self.replacements.get(node.attr, node.attr), + node.ctx, + ), + node, + ) + + +class AsyncToSync(ast.NodeTransformer): + + def visit_Await(self, node): + return self.visit(node.value) + + def visit_AsyncFor(self, node): + return ast.copy_location( + ast.For( + self.visit(node.target), + self.visit(node.iter), + [self.visit(stmt) for stmt in node.body], + [self.visit(stmt) for stmt in node.orelse], + ), + node, + ) + + def visit_AsyncWith(self, node): + return ast.copy_location( + ast.With( + [self.visit(item) for item in node.items], + [self.visit(stmt) for stmt in node.body], + ), + node, + ) + + def visit_AsyncFunctionDef(self, node): + return ast.copy_location( + ast.FunctionDef( + node.name, + self.visit(node.args), + [self.visit(stmt) for stmt in node.body], + [self.visit(decorator) for decorator in node.decorator_list], + node.returns and self.visit(node.returns), + ), + node, + ) + + def visit_ListComp(self, node): + # replace [x async for ...] with [x for ...] 
+ new_generators = [] + for generator in node.generators: + if generator.is_async: + new_generators.append( + ast.copy_location( + ast.comprehension( + self.visit(generator.target), + self.visit(generator.iter), + [self.visit(i) for i in generator.ifs], + False, + ), + generator, + ) + ) + else: + new_generators.append(generator) + node.generators = new_generators + return ast.copy_location( + ast.ListComp( + self.visit(node.elt), + [self.visit(gen) for gen in node.generators], + ), + node, + ) + +class HandleCrossSyncDecorators(ast.NodeTransformer): + + def visit_FunctionDef(self, node): + if hasattr(node, "decorator_list"): + revised_list = [d for d in node.decorator_list if "CrossSync" not in ast.dump(d)] + node.decorator_list = revised_list + return node + + From fd639e4995250782f00ec5bde9bfe5656aced5d5 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 10:46:37 -0600 Subject: [PATCH 112/360] fixed docstrings --- .../cloud/bigtable/data/_sync/transformers.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index 7197c06f1..7e8b081ce 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -20,6 +20,25 @@ def visit_Attribute(self, node): node, ) + def update_docstring(self, docstring): + """ + Update docstring to replace any key words in the replacements dict + """ + if not docstring: + return docstring + for key_word, replacement in self.replacements.items(): + docstring = docstring.replace(f" {key_word} ", f" {replacement} ") + return docstring + + def visit_FunctionDef(self, node): + # replace docstring + docstring = self.update_docstring(ast.get_docstring(node)) + if isinstance(node.body[0], ast.Expr) and isinstance( + node.body[0].value, ast.Str + ): + node.body[0].value.s = docstring + return self.generic_visit(node) + class AsyncToSync(ast.NodeTransformer): From c1053e9f4a03e36ff13c0ee8d66e82b3125e0045 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 10:51:54 -0600 Subject: [PATCH 113/360] made changes to docstring format in generated files --- .../cloud/bigtable/data/_async/_read_rows.py | 1 - .../cloud/bigtable/data/_sync/_mutate_rows.py | 43 +++-- .../cloud/bigtable/data/_sync/_read_rows.py | 61 ++++--- google/cloud/bigtable/data/_sync/client.py | 150 +++++++----------- .../bigtable/data/_sync/mutations_batcher.py | 92 ++++------- 5 files changed, 135 insertions(+), 212 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index d46b9ec6a..7c3b27ad3 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -56,7 +56,6 @@ def __init__(self, chunk): replace_symbols={ "AsyncIterable": "Iterable", "StopAsyncIteration": "StopIteration", - "Awaitable": None, "TableAsync": "Table", "__aiter__": "__iter__", "__anext__": "__next__", diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index 738df0668..75beeca75 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -12,35 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# This file is automatically generated by sync_surface_generator.py. Do not edit. 
- - +# This file is automatically generated by CrossSync. Do not edit manually. from __future__ import annotations from typing import Sequence, TYPE_CHECKING import functools - from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries -from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _make_metadata +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _retry_exception_factory -from google.cloud.bigtable.data._sync.cross_sync import CrossSync -from google.cloud.bigtable.data.exceptions import FailedMutationEntryError +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.exceptions import RetryExceptionGroup -from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if not CrossSync._Sync_Impl.is_async: + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry if CrossSync._Sync_Impl.is_async: - pass + from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) else: - from google.cloud.bigtable.data._sync.client import Table - from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient -if not CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto + pass class _MutateRowsOperation: @@ -103,12 +102,10 @@ def __init__( self.errors: dict[int, list[Exception]] = {} def start(self): - """ - Start the operation, and run until completion + """Start the operation, and run until completion Raises: - MutationsExceptionGroup: if any mutations failed - """ + MutationsExceptionGroup: if any mutations failed""" try: self._operation() except Exception as exc: @@ -132,14 +129,12 @@ def start(self): raise MutationsExceptionGroup(all_errors, len(self.mutations)) def _run_attempt(self): - """ - Run a single attempt of the mutate_rows rpc. + """Run a single attempt of the mutate_rows rpc. Raises: _MutateRowsIncomplete: if there are failed mutations eligible for retry after the attempt is complete - GoogleAPICallError: if the gapic rpc fails - """ + GoogleAPICallError: if the gapic rpc fails""" request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] active_request_indices = { req_idx: orig_idx for req_idx, orig_idx in enumerate(self.remaining_indices) @@ -174,15 +169,13 @@ def _run_attempt(self): raise _MutateRowsIncomplete def _handle_entry_error(self, idx: int, exc: Exception): - """ - Add an exception to the list of exceptions for a given mutation index, + """Add an exception to the list of exceptions for a given mutation index, and add the index to the list of remaining indices if the exception is retryable. 
Args: idx: the index of the mutation that failed - exc: the exception to add to the list - """ + exc: the exception to add to the list""" entry = self.mutations[idx].entry self.errors.setdefault(idx, []).append(exc) if ( diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index f46b80f4c..4fb7fea7a 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -12,24 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# This file is automatically generated by sync_surface_generator.py. Do not edit. - - +# This file is automatically generated by CrossSync. Do not edit manually. from __future__ import annotations -from typing import Sequence - -from google.api_core import retry as retries -from google.api_core.retry import exponential_sleep_generator -from google.cloud.bigtable.data import _helpers -from google.cloud.bigtable.data._sync.cross_sync import CrossSync -from google.cloud.bigtable.data.exceptions import InvalidChunk -from google.cloud.bigtable.data.exceptions import _RowSetComplete -from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery -from google.cloud.bigtable.data.row import Row, Cell +from typing import TYPE_CHECKING, Awaitable, Sequence from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB -from google.cloud.bigtable_v2.types import RowRange as RowRangePB from google.cloud.bigtable_v2.types import RowSet as RowSetPB +from google.cloud.bigtable_v2.types import RowRange as RowRangePB +from google.cloud.bigtable.data.row import Row, Cell +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.exceptions import _RowSetComplete +from google.cloud.bigtable.data import _helpers +from google.api_core import retry as retries +from google.api_core.retry import exponential_sleep_generator +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +if TYPE_CHECKING: + if CrossSync._Sync_Impl.is_async: + from google.cloud.bigtable.data._async.client import TableAsync + else: + from typing import Iterable class _ReadRowsOperation: @@ -90,12 +93,10 @@ def __init__( self._remaining_count: int | None = self.request.rows_limit or None def start_operation(self) -> Iterable[Row]: - """ - Start the read_rows operation, retrying on retryable errors. + """Start the read_rows operation, retrying on retryable errors. Yields: - Row: The next row in the stream - """ + Row: The next row in the stream""" return CrossSync._Sync_Impl.retry_target_stream( self._read_rows_attempt, self._predicate, @@ -105,15 +106,13 @@ def start_operation(self) -> Iterable[Row]: ) def _read_rows_attempt(self) -> Iterable[Row]: - """ - Attempt a single read_rows rpc call. + """Attempt a single read_rows rpc call. This function is intended to be wrapped by retry logic, which will call this function until it succeeds or a non-retryable error is raised. 
Yields: - Row: The next row in the stream - """ + Row: The next row in the stream""" if self._last_yielded_row_key is not None: try: self.request.rows = self._revise_request_rowset( @@ -138,14 +137,12 @@ def _read_rows_attempt(self) -> Iterable[Row]: def chunk_stream( self, stream: Iterable[ReadRowsResponsePB] ) -> Iterable[ReadRowsResponsePB.CellChunk]: - """ - process chunks out of raw read_rows stream + """process chunks out of raw read_rows stream Args: stream: the raw read_rows stream from the gapic client Yields: - ReadRowsResponsePB.CellChunk: the next chunk in the stream - """ + ReadRowsResponsePB.CellChunk: the next chunk in the stream""" for resp in stream: resp = resp._pb if resp.last_scanned_row_key: @@ -181,14 +178,12 @@ def chunk_stream( def merge_rows( chunks: Iterable[ReadRowsResponsePB.CellChunk] | None, ) -> Iterable[Row]: - """ - Merge chunks into rows + """Merge chunks into rows Args: chunks: the chunk stream to merge Yields: - Row: the next row in the stream - """ + Row: the next row in the stream""" if chunks is None: return it = chunks.__iter__() @@ -280,8 +275,7 @@ def merge_rows( @staticmethod def _revise_request_rowset(row_set: RowSetPB, last_seen_row_key: bytes) -> RowSetPB: - """ - Revise the rows in the request to avoid ones we've already processed. + """Revise the rows in the request to avoid ones we've already processed. Args: row_set: the row set from the request @@ -289,8 +283,7 @@ def _revise_request_rowset(row_set: RowSetPB, last_seen_row_key: bytes) -> RowSe Returns: RowSetPB: the new rowset after adusting for the last seen key Raises: - _RowSetComplete: if there are no rows left to process after the revision - """ + _RowSetComplete: if there are no rows left to process after the revision""" if row_set is None or (not row_set.row_ranges and (not row_set.row_keys)): last_seen = last_seen_row_key return RowSetPB(row_ranges=[RowRangePB(start_key_open=last_seen)]) diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 5ef6fd325..9e75637d3 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -12,47 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# This file is automatically generated by sync_surface_generator.py. Do not edit. - - +# This file is automatically generated by CrossSync. Do not edit manually. 
from __future__ import annotations -from functools import partial -from grpc import Channel from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING import asyncio -import concurrent.futures -import os -import random import time import warnings - -from google.api_core import client_options as client_options_lib +import random +import os +import concurrent.futures +from functools import partial +from grpc import Channel +from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta +from google.cloud.bigtable_v2.services.bigtable.transports.base import ( + DEFAULT_CLIENT_INFO, +) +from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest +from google.cloud.client import ClientWithProject +from google.cloud.environment_vars import BIGTABLE_EMULATOR from google.api_core import retry as retries -from google.api_core.exceptions import Aborted from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import ServiceUnavailable +from google.api_core.exceptions import Aborted +import google.auth.credentials +import google.auth._default +from google.api_core import client_options as client_options_lib from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT +from google.cloud.bigtable.data.row import Row +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.exceptions import FailedQueryShardError +from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.cloud.bigtable.data import _helpers from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import _MB_SIZE -from google.cloud.bigtable.data._sync.cross_sync import CrossSync -from google.cloud.bigtable.data.exceptions import FailedQueryShardError -from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule -from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery -from google.cloud.bigtable.data.row import Row -from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter from google.cloud.bigtable.data.row_filters import RowFilter -from google.cloud.bigtable.data.row_filters import RowFilterChain from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter -from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.transports.base import ( - DEFAULT_CLIENT_INFO, -) -from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest -from google.cloud.client import ClientWithProject -from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter +from google.cloud.bigtable.data.row_filters import RowFilterChain +from google.cloud.bigtable.data._sync.cross_sync import CrossSync if CrossSync._Sync_Impl.is_async: pass @@ -71,8 +70,6 @@ if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery -import google.auth._default -import google.auth.credentials class BigtableDataClient(ClientWithProject): @@ -86,8 +83,7 @@ def __init__( | "google.api_core.client_options.ClientOptions" | None = None, ): - """ - Create a client instance for the Bigtable Data API + """Create a client instance for the Bigtable 
Data API Client should be created within an async context (running event loop) @@ -107,8 +103,7 @@ def __init__( on the client. API Endpoint should be set through client_options. Raises: RuntimeError: if called outside of an async context (no running event loop) - ValueError: if pool_size is less than 1 - """ + ValueError: if pool_size is less than 1""" transport_str = f"bt-{self._client_version()}-{pool_size}" transport = PooledBigtableGrpcTransport.with_fixed_size(pool_size) BigtableClientMeta._transport_registry[transport_str] = transport @@ -178,12 +173,10 @@ def _client_version() -> str: return f"{google.cloud.bigtable.__version__}-data" def _start_background_channel_refresh(self) -> None: - """ - Starts a background task to ping and warm each channel in the pool + """Starts a background task to ping and warm each channel in the pool Raises: - RuntimeError: if not called in an asyncio event loop - """ + RuntimeError: if not called in an asyncio event loop""" if ( not self._channel_refresh_tasks and (not self._emulator_host) @@ -217,8 +210,7 @@ def close(self, timeout: float | None = None): def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None ) -> list[BaseException | None]: - """ - Prepares the backend for requests on a channel + """Prepares the backend for requests on a channel Pings each Bigtable instance registered in `_active_instances` on the client @@ -261,8 +253,7 @@ def _manage_channel( refresh_interval_max: float = 60 * 45, grace_period: float = 60 * 10, ) -> None: - """ - Background coroutine that periodically refreshes and warms a grpc channel + """Background coroutine that periodically refreshes and warms a grpc channel The backend will automatically close channels after 60 minutes, so `refresh_interval` + `grace_period` should be < 60 minutes @@ -278,8 +269,7 @@ def _manage_channel( process in seconds. Actual interval will be a random value between `refresh_interval_min` and `refresh_interval_max` grace_period: time to allow previous channel to serve existing - requests before closing, in seconds - """ + requests before closing, in seconds""" first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) @@ -306,8 +296,7 @@ def _manage_channel( next_sleep = next_refresh - (time.monotonic() - start_timestamp) def _register_instance(self, instance_id: str, owner: Table) -> None: - """ - Registers an instance with the client, and warms the channel pool + """Registers an instance with the client, and warms the channel pool for the instance The client will periodically refresh grpc channel pool used to make requests, and new channels will be warmed for each registered instance @@ -317,8 +306,7 @@ def _register_instance(self, instance_id: str, owner: Table) -> None: instance_id: id of the instance to register. owner: table that owns the instance. 
Owners will be tracked in _instance_owners, and instances will only be unregistered when all - owners call _remove_instance_registration - """ + owners call _remove_instance_registration""" instance_name = self._gapic_client.instance_path(self.project, instance_id) instance_key = _helpers._WarmedInstanceKey( instance_name, owner.table_name, owner.app_profile_id @@ -333,8 +321,7 @@ def _register_instance(self, instance_id: str, owner: Table) -> None: self._start_background_channel_refresh() def _remove_instance_registration(self, instance_id: str, owner: Table) -> bool: - """ - Removes an instance from the client's registered instances, to prevent + """Removes an instance from the client's registered instances, to prevent warming new channels for the instance If instance_id is not registered, or is still in use by other tables, returns False @@ -345,8 +332,7 @@ def _remove_instance_registration(self, instance_id: str, owner: Table) -> bool: _instance_owners, and instances will only be unregistered when all owners call _remove_instance_registration Returns: - bool: True if instance was removed, else False - """ + bool: True if instance was removed, else False""" instance_name = self._gapic_client.instance_path(self.project, instance_id) instance_key = _helpers._WarmedInstanceKey( instance_name, owner.table_name, owner.app_profile_id @@ -361,8 +347,7 @@ def _remove_instance_registration(self, instance_id: str, owner: Table) -> bool: return False def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> Table: - """ - Returns a table instance for making data API requests. All arguments are passed + """Returns a table instance for making data API requests. All arguments are passed directly to the Table constructor. Args: @@ -445,8 +430,7 @@ def __init__( ServiceUnavailable, ), ): - """ - Initialize a Table instance + """Initialize a Table instance Must be created within an async context (running event loop) @@ -541,8 +525,7 @@ def read_rows_stream( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> Iterable[Row]: - """ - Read a set of rows from the table, based on the specified query. + """Read a set of rows from the table, based on the specified query. Returns an iterator to asynchronously stream back row data. Failed requests within operation_timeout will be retried based on the @@ -590,8 +573,7 @@ def read_rows( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: - """ - Read a set of rows from the table, based on the specified query. + """Read a set of rows from the table, based on the specified query. Retruns results as a list of Row objects when the request is complete. For streamed results, use read_rows_stream. @@ -638,8 +620,7 @@ def read_row( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> Row | None: - """ - Read a single row from the table, based on the specified key. + """Read a single row from the table, based on the specified key. Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. @@ -686,8 +667,7 @@ def read_rows_sharded( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: - """ - Runs a sharded query in parallel, then return the results in a single list. + """Runs a sharded query in parallel, then return the results in a single list. Results will be returned in the order of the input queries. 
This function is intended to be run on the results on a query.shard() call. @@ -714,8 +694,7 @@ def read_rows_sharded( list[Row]: a list of Rows returned by the query Raises: ShardedReadRowsExceptionGroup: if any of the queries failed - ValueError: if the query_list is empty - """ + ValueError: if the query_list is empty""" if not sharded_query: raise ValueError("empty sharded_query") operation_timeout, attempt_timeout = _helpers._get_timeouts( @@ -777,8 +756,7 @@ def row_exists( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> bool: - """ - Return a boolean indicating whether the specified row exists in the table. + """Return a boolean indicating whether the specified row exists in the table. uses the filters: chain(limit cells per row = 1, strip value) Args: @@ -823,8 +801,7 @@ def sample_row_keys( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> RowKeySamples: - """ - Return a set of RowKeySamples that delimit contiguous sections of the table of + """Return a set of RowKeySamples that delimit contiguous sections of the table of approximately equal size RowKeySamples output can be used with ReadRowsQuery.shard() to create a sharded query that @@ -895,8 +872,7 @@ def mutations_batcher( batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, ) -> MutationsBatcher: - """ - Returns a new mutations batcher instance. + """Returns a new mutations batcher instance. Can be used to iteratively add mutations that are flushed as a group, to avoid excess network calls @@ -941,8 +917,7 @@ def mutate_row( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ): - """ - Mutates a row atomically. + """Mutates a row atomically. Cells already present in the row are left unchanged unless explicitly changed by ``mutation``. @@ -970,8 +945,7 @@ def mutate_row( GoogleAPIError exceptions from any retries that failed google.api_core.exceptions.GoogleAPIError: raised on non-idempotent operations that cannot be safely retried. - ValueError: if invalid arguments are provided - """ + ValueError: if invalid arguments are provided""" operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) @@ -1012,8 +986,7 @@ def bulk_mutate_rows( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, ): - """ - Applies mutations for multiple rows in a single batched request. + """Applies mutations for multiple rows in a single batched request. 
Each individual RowMutationEntry is applied atomically, but separate entries may be applied in arbitrary order (even for entries targetting the same row) @@ -1041,8 +1014,7 @@ def bulk_mutate_rows( Raises: MutationsExceptionGroup: if one or more mutations fails Contains details about any failed entries in .exceptions - ValueError: if invalid arguments are provided - """ + ValueError: if invalid arguments are provided""" operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) @@ -1066,8 +1038,7 @@ def check_and_mutate_row( false_case_mutations: Mutation | list[Mutation] | None = None, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> bool: - """ - Mutates a row atomically based on the output of a predicate filter + """Mutates a row atomically based on the output of a predicate filter Non-idempotent operation: will not be retried @@ -1096,8 +1067,7 @@ def check_and_mutate_row( Returns: bool indicating whether the predicate was true or false Raises: - google.api_core.exceptions.GoogleAPIError: exceptions from grpc call - """ + google.api_core.exceptions.GoogleAPIError: exceptions from grpc call""" operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) if true_case_mutations is not None and ( not isinstance(true_case_mutations, list) @@ -1130,8 +1100,7 @@ def read_modify_write_row( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> Row: - """ - Reads and modifies a row atomically according to input ReadModifyWriteRules, + """Reads and modifies a row atomically according to input ReadModifyWriteRules, and returns the contents of all modified cells The new value for the timestamp is the greater of the existing timestamp or @@ -1151,8 +1120,7 @@ def read_modify_write_row( Row: a Row containing cell data that was modified as part of the operation Raises: google.api_core.exceptions.GoogleAPIError: exceptions from grpc call - ValueError: if invalid arguments are provided - """ + ValueError: if invalid arguments are provided""" operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) if operation_timeout <= 0: raise ValueError("operation_timeout must be greater than 0") @@ -1179,21 +1147,17 @@ def close(self): self.client._remove_instance_registration(self.instance_id, self) def __enter__(self): - """ - Implement async context manager protocol + """Implement async context manager protocol Ensure registration task has time to run, so that - grpc channels will be warmed for the specified instance - """ + grpc channels will be warmed for the specified instance""" if self._register_instance_future: self._register_instance_future return self def __exit__(self, exc_type, exc_val, exc_tb): - """ - Implement async context manager protocol + """Implement async context manager protocol Unregister this instance with the client, so that - grpc channels will no longer be warmed - """ + grpc channels will no longer be warmed""" self.close() diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 8fa12022a..c04e49afa 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -12,25 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# This file is automatically generated by sync_surface_generator.py. Do not edit. 
# mypy: disable-error-code="unreachable" - +# This file is automatically generated by CrossSync. Do not edit manually. from __future__ import annotations -from collections import deque from typing import Sequence, TYPE_CHECKING import atexit -import concurrent.futures import warnings - -from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data._helpers import _MB_SIZE +from collections import deque +import concurrent.futures +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup +from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts -from google.cloud.bigtable.data._sync.cross_sync import CrossSync -from google.cloud.bigtable.data.exceptions import FailedMutationEntryError -from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup -from google.cloud.bigtable.data.mutations import Mutation +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _MB_SIZE from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable.data.mutations import Mutation +from google.cloud.bigtable.data._sync.cross_sync import CrossSync if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -125,15 +123,13 @@ def __init__( atexit.register(self._on_exit) def _timer_routine(self, interval: float | None) -> None: - """ - Set up a background task to flush the batcher every interval seconds + """Set up a background task to flush the batcher every interval seconds If interval is None, an empty future is returned Args: flush_interval: Automatically flush every flush_interval seconds. - If None, no time-based flushing is performed. 
- """ + If None, no time-based flushing is performed.""" if not interval or interval <= 0: return None while not self._closed.is_set(): @@ -144,15 +140,13 @@ def _timer_routine(self, interval: float | None) -> None: self._schedule_flush() def append(self, mutation_entry: RowMutationEntry): - """ - Add a new set of mutations to the internal queue + """Add a new set of mutations to the internal queue Args: mutation_entry: new entry to add to flush queue Raises: RuntimeError: if batcher is closed - ValueError: if an invalid mutation type is added - """ + ValueError: if an invalid mutation type is added""" if self._closed.is_set(): raise RuntimeError("Cannot append to closed MutationsBatcher") if isinstance(mutation_entry, Mutation): @@ -170,13 +164,11 @@ def append(self, mutation_entry: RowMutationEntry): CrossSync._Sync_Impl.yield_to_event_loop() def _schedule_flush(self) -> CrossSync._Sync_Impl.Future[None] | None: - """ - Update the flush task to include the latest staged entries + """Update the flush task to include the latest staged entries Returns: Future[None] | None: - future representing the background task, if started - """ + future representing the background task, if started""" if self._staged_entries: entries, self._staged_entries = (self._staged_entries, []) self._staged_count, self._staged_bytes = (0, 0) @@ -190,12 +182,10 @@ def _schedule_flush(self) -> CrossSync._Sync_Impl.Future[None] | None: return None def _flush_internal(self, new_entries: list[RowMutationEntry]): - """ - Flushes a set of mutations to the server, and updates internal state + """Flushes a set of mutations to the server, and updates internal state Args: - new_entries list of RowMutationEntry objects to flush - """ + new_entries list of RowMutationEntry objects to flush""" in_process_requests: list[ CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] ] = [] @@ -211,8 +201,7 @@ def _flush_internal(self, new_entries: list[RowMutationEntry]): def _execute_mutate_rows( self, batch: list[RowMutationEntry] ) -> list[FailedMutationEntryError]: - """ - Helper to execute mutation operation on a batch + """Helper to execute mutation operation on a batch Args: batch: list of RowMutationEntry objects to send to server @@ -221,8 +210,7 @@ def _execute_mutate_rows( Returns: list[FailedMutationEntryError]: list of FailedMutationEntryError objects for mutations that failed. - FailedMutationEntryError objects will not contain index information - """ + FailedMutationEntryError objects will not contain index information""" try: operation = _MutateRowsOperation( self._table.client._gapic_client, @@ -242,14 +230,12 @@ def _execute_mutate_rows( return [] def _add_exceptions(self, excs: list[Exception]): - """ - Add new list of exceptions to internal store. To avoid unbounded memory, + """Add new list of exceptions to internal store. To avoid unbounded memory, the batcher will store the first and last _exception_list_limit exceptions, and discard any in between. 
Args: - excs: list of exceptions to add to the internal store - """ + excs: list of exceptions to add to the internal store""" self._exceptions_since_last_raise += len(excs) if excs and len(self._oldest_exceptions) < self._exception_list_limit: addition_count = self._exception_list_limit - len(self._oldest_exceptions) @@ -259,12 +245,10 @@ def _add_exceptions(self, excs: list[Exception]): self._newest_exceptions.extend(excs[-self._exception_list_limit :]) def _raise_exceptions(self): - """ - Raise any unreported exceptions from background flush operations + """Raise any unreported exceptions from background flush operations Raises: - MutationsExceptionGroup: exception group with all unreported exceptions - """ + MutationsExceptionGroup: exception group with all unreported exceptions""" if self._oldest_exceptions or self._newest_exceptions: oldest, self._oldest_exceptions = (self._oldest_exceptions, []) newest = list(self._newest_exceptions) @@ -289,19 +273,15 @@ def __enter__(self): return self def __exit__(self, exc_type, exc, tb): - """ - Allow use of context manager API. + """Allow use of context manager API. - Flushes the batcher and cleans up resources. - """ + Flushes the batcher and cleans up resources.""" self.close() @property def closed(self) -> bool: - """ - Returns: - - True if the batcher is closed, False otherwise - """ + """Returns: + - True if the batcher is closed, False otherwise""" return self._closed.is_set() def close(self): @@ -328,8 +308,7 @@ def _wait_for_batch_results( *tasks: CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] | CrossSync._Sync_Impl.Future[None], ) -> list[Exception]: - """ - Takes in a list of futures representing _execute_mutate_rows tasks, + """Takes in a list of futures representing _execute_mutate_rows tasks, waits for them to complete, and returns a list of errors encountered. Args: @@ -389,8 +368,7 @@ def __init__(self, max_mutation_count: int, max_mutation_bytes: int): self._in_flight_mutation_bytes = 0 def _has_capacity(self, additional_count: int, additional_size: int) -> bool: - """ - Checks if there is capacity to send a new entry with the given size and count + """Checks if there is capacity to send a new entry with the given size and count FlowControl limits are not hard limits. If a single mutation exceeds the configured flow limits, it will be sent in a single batch when @@ -411,14 +389,12 @@ def _has_capacity(self, additional_count: int, additional_size: int) -> bool: def remove_from_flow( self, mutations: RowMutationEntry | list[RowMutationEntry] ) -> None: - """ - Removes mutations from flow control. This method should be called once + """Removes mutations from flow control. This method should be called once for each mutation that was sent to add_to_flow, after the corresponding operation is complete. Args: - mutations: mutation or list of mutations to remove from flow control - """ + mutations: mutation or list of mutations to remove from flow control""" if not isinstance(mutations, list): mutations = [mutations] total_count = sum((len(entry.mutations) for entry in mutations)) @@ -429,8 +405,7 @@ def remove_from_flow( self._capacity_condition.notify_all() def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): - """ - Generator function that registers mutations with flow control. As mutations + """Generator function that registers mutations with flow control. As mutations are accepted into the flow control, they are yielded back to the caller, to be sent in a batch. 
         If the flow control is at capacity, the generator will block until there is capacity available.
@@ -440,8 +415,7 @@ def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]):
         Yields:
             list[RowMutationEntry]:
                 list of mutations that have reserved space in the flow control.
-                Each batch contains at least one mutation.
-        """
+                Each batch contains at least one mutation."""
         if not isinstance(mutations, list):
             mutations = [mutations]
         start_idx = 0

From c925760e45d903cf7374e8d417ae500a7ab8bc5c Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Wed, 3 Jul 2024 10:53:23 -0600
Subject: [PATCH 114/360] simplified generator visit

---
 .../cloud/bigtable/data/_sync/transformers.py | 25 ++-----------------
 1 file changed, 2 insertions(+), 23 deletions(-)

diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py
index 7e8b081ce..93f203594 100644
--- a/google/cloud/bigtable/data/_sync/transformers.py
+++ b/google/cloud/bigtable/data/_sync/transformers.py
@@ -79,30 +79,9 @@ def visit_AsyncFunctionDef(self, node):
 
     def visit_ListComp(self, node):
         # replace [x async for ...] with [x for ...]
-        new_generators = []
         for generator in node.generators:
-            if generator.is_async:
-                new_generators.append(
-                    ast.copy_location(
-                        ast.comprehension(
-                            self.visit(generator.target),
-                            self.visit(generator.iter),
-                            [self.visit(i) for i in generator.ifs],
-                            False,
-                        ),
-                        generator,
-                    )
-                )
-            else:
-                new_generators.append(generator)
-        node.generators = new_generators
-        return ast.copy_location(
-            ast.ListComp(
-                self.visit(node.elt),
-                [self.visit(gen) for gen in node.generators],
-            ),
-            node,
-        )
+            generator.is_async = False
+        return self.generic_visit(node)
 
 class HandleCrossSyncDecorators(ast.NodeTransformer):
 
From 9f9ec0f0916eec5df1633331712a8f68afb2724f Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Wed, 3 Jul 2024 10:58:42 -0600
Subject: [PATCH 115/360] replace string types

---
 google/cloud/bigtable/data/_sync/transformers.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py
index 93f203594..610e8cec1 100644
--- a/google/cloud/bigtable/data/_sync/transformers.py
+++ b/google/cloud/bigtable/data/_sync/transformers.py
@@ -39,6 +39,11 @@ def visit_FunctionDef(self, node):
             node.body[0].value.s = docstring
         return self.generic_visit(node)
 
+    def visit_Str(self, node):
+        """Used to replace string type annotations"""
+        node.s = self.replacements.get(node.s, node.s)
+        return node
+
 
 class AsyncToSync(ast.NodeTransformer):
 
From e5168a133dbe83780ae6338fea5853456f09cb96 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Wed, 3 Jul 2024 11:04:08 -0600
Subject: [PATCH 116/360] add mypy disabling

---
 google/cloud/bigtable/data/_sync/cross_sync.py | 6 ++++++
 google/cloud/bigtable/data/_sync/mutations_batcher.py | 4 +++-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py
index 6be4ad2a6..e0829afe4 100644
--- a/google/cloud/bigtable/data/_sync/cross_sync.py
+++ b/google/cloud/bigtable/data/_sync/cross_sync.py
@@ -290,6 +290,7 @@ class CrossSyncArtifact:
     file_path: str
     imports: list[ast.Import | ast.ImportFrom] = field(default_factory=list)
    converted_classes: dict[type, ast.ClassDef] = field(default_factory=dict)
+    mypy_ignores: list[str] = field(default_factory=list)
     _instances: ClassVar[dict[str, CrossSyncArtifact]] = {}
 
     def __hash__(self):
         return hash(self.file_path)
@@ -313,6 +314,8 @@
def render(self, with_black=True) -> str: '#\n' '# This file is automatically generated by CrossSync. Do not edit manually.\n' ) + if self.mypy_ignores: + full_str += f'\n# mypy: disable-error-code="{",".join(self.mypy_ignores)}"\n\n' full_str += "\n".join([ast.unparse(node) for node in self.imports]) full_str += "\n\n" full_str += "\n".join([ast.unparse(node) for node in self.converted_classes.values()]) @@ -353,6 +356,9 @@ def add_class(self, cls): for node in full_ast.body: if isinstance(node, (ast.Import, ast.ImportFrom, ast.If)): self.imports.append(crosssync_converter.visit(node)) + # add mypy ignore if required + if cls.cross_sync_mypy_ignore: + self.mypy_ignores.extend(cls.cross_sync_mypy_ignore) if __name__ == "__main__": import os diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index c04e49afa..524dab78d 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# mypy: disable-error-code="unreachable" # This file is automatically generated by CrossSync. Do not edit manually. + +# mypy: disable-error-code="unreachable" + from __future__ import annotations from typing import Sequence, TYPE_CHECKING import atexit From 9bd13b32ada062c3cdc906d4a5e27f07e18b7238 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 11:04:40 -0600 Subject: [PATCH 117/360] regenerated imports --- google/cloud/bigtable/data/_sync/_mutate_rows.py | 8 +++----- google/cloud/bigtable/data/_sync/_read_rows.py | 3 ++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index 75beeca75..d65cf3c61 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -34,12 +34,10 @@ from google.cloud.bigtable.data.mutations import RowMutationEntry if CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) - else: pass + else: + from google.cloud.bigtable.data._sync.client import Table + from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient class _MutateRowsOperation: diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index 4fb7fea7a..62755e6e0 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -30,8 +30,9 @@ if TYPE_CHECKING: if CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable.data._async.client import TableAsync + pass else: + from google.cloud.bigtable.data._sync.client import Table from typing import Iterable From 6895968675ce51c74af070fa92c0cbd736cb34e3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 11:24:36 -0600 Subject: [PATCH 118/360] got rename_sync decorator working --- google/cloud/bigtable/data/_async/client.py | 2 ++ google/cloud/bigtable/data/_sync/transformers.py | 11 +++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 00a3ee419..bd86923cb 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ 
b/google/cloud/bigtable/data/_async/client.py @@ -478,10 +478,12 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs """ return TableAsync(self, instance_id, table_id, *args, **kwargs) + @CrossSync.rename_sync("__enter__") async def __aenter__(self): self._start_background_channel_refresh() return self + @CrossSync.rename_sync("__exit__") async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index 610e8cec1..a65078771 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -92,8 +92,15 @@ class HandleCrossSyncDecorators(ast.NodeTransformer): def visit_FunctionDef(self, node): if hasattr(node, "decorator_list"): - revised_list = [d for d in node.decorator_list if "CrossSync" not in ast.dump(d)] - node.decorator_list = revised_list + found_list, node.decorator_list = node.decorator_list, [] + for decorator in found_list: + if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name) and "CrossSync" in decorator.func.value.id: + decorator_type = decorator.func.attr + if decorator_type == "rename_sync": + node.name = decorator.args[0].value + else: + # add non-crosssync decorators back + node.decorator_list.append(decorator) return node From 09c090d115fd9f3bc2dbc63922d4d920375c2759 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 11:39:11 -0600 Subject: [PATCH 119/360] made convert decorator --- google/cloud/bigtable/data/_async/client.py | 10 ++++------ google/cloud/bigtable/data/_async/mutations_batcher.py | 4 ++-- google/cloud/bigtable/data/_sync/cross_sync.py | 2 +- google/cloud/bigtable/data/_sync/transformers.py | 9 +++++++-- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index bd86923cb..1b0dc0239 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -113,8 +113,6 @@ @CrossSync.sync_output( "google.cloud.bigtable.data._sync.client.BigtableDataClient", replace_symbols={ - "__aenter__": "__enter__", - "__aexit__": "__exit__", "TableAsync": "Table", "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", "BigtableAsyncClient": "BigtableClient", @@ -478,12 +476,12 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs """ return TableAsync(self, instance_id, table_id, *args, **kwargs) - @CrossSync.rename_sync("__enter__") + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): self._start_background_channel_refresh() return self - @CrossSync.rename_sync("__exit__") + @CrossSync.convert(sync_name="__exit__", replace_symbols={"__aexit__": "__exit__"}) async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) @@ -1298,7 +1296,7 @@ async def close(self): self._register_instance_future.cancel() await self.client._remove_instance_registration(self.instance_id, self) - @CrossSync.rename_sync("__enter__") + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """ Implement async context manager protocol @@ -1310,7 +1308,7 @@ async def __aenter__(self): await self._register_instance_future return self - 
@CrossSync.rename_sync("__exit__") + @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc_val, exc_tb): """ Implement async context manager protocol diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 643e90126..1f5b26bbb 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -441,12 +441,12 @@ def _raise_exceptions(self): entry_count=entry_count, ) - @CrossSync.rename_sync("__enter__") + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """Allow use of context manager API""" return self - @CrossSync.rename_sync("__exit__") + @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc, tb): """ Allow use of context manager API. diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index e0829afe4..70781de2d 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -47,7 +47,7 @@ class CrossSync: generated_replacements: dict[type, str] = {} @staticmethod - def rename_sync(*args, **kwargs): + def convert(*, sync_name: str|None=None, replace_symbols: dict[str, str]|None=None): def decorator(func): return func diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index a65078771..fcf9092ea 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -96,8 +96,13 @@ def visit_FunctionDef(self, node): for decorator in found_list: if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name) and "CrossSync" in decorator.func.value.id: decorator_type = decorator.func.attr - if decorator_type == "rename_sync": - node.name = decorator.args[0].value + if decorator_type == "convert": + for subcommand in decorator.keywords: + if subcommand.arg == "sync_name": + node.name = subcommand.value.s + if subcommand.arg == "replace_symbols": + replacements = {subcommand.value.keys[i].s: subcommand.value.values[i].s for i in range(len(subcommand.value.keys))} + node = SymbolReplacer(replacements).visit(node) else: # add non-crosssync decorators back node.decorator_list.append(decorator) From ab28899789b1f2b21560c9280b15b4f5d36b36d9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 12:07:45 -0600 Subject: [PATCH 120/360] added CrossSync.Awaitable --- google/cloud/bigtable/data/_async/_mutate_rows.py | 6 +++--- google/cloud/bigtable/data/_async/_read_rows.py | 5 ++++- google/cloud/bigtable/data/_sync/_mutate_rows.py | 4 ++-- google/cloud/bigtable/data/_sync/_read_rows.py | 6 ++++-- google/cloud/bigtable/data/_sync/cross_sync.py | 4 +++- 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index dc2c81052..d3f12af3b 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -34,9 +34,6 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -if not CrossSync.is_async: - from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto - if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -60,6 +57,9 @@ class _EntryWithProto: entry: RowMutationEntry proto: 
types_pb.MutateRowsRequest.Entry +if not CrossSync.is_async: + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto + @CrossSync.sync_output( "google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 7c3b27ad3..09cc59e03 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -38,6 +38,7 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync + if TYPE_CHECKING: if CrossSync.is_async: from google.cloud.bigtable.data._async.client import TableAsync @@ -50,6 +51,8 @@ class _ResetRow(Exception): def __init__(self, chunk): self.chunk = chunk +if not CrossSync.is_async: + from google.cloud.bigtable.data._async._read_rows import _ResetRow @CrossSync.sync_output( "google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", @@ -173,7 +176,7 @@ def _read_rows_attempt(self) -> AsyncIterable[Row]: return self.merge_rows(chunked_stream) async def chunk_stream( - self, stream: Awaitable[AsyncIterable[ReadRowsResponsePB]] + self, stream: CrossSync.Awaitable[AsyncIterable[ReadRowsResponsePB]] ) -> AsyncIterable[ReadRowsResponsePB.CellChunk]: """ process chunks out of raw read_rows stream diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index d65cf3c61..11591c007 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -28,8 +28,6 @@ from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT from google.cloud.bigtable.data._sync.cross_sync import CrossSync -if not CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -38,6 +36,8 @@ else: from google.cloud.bigtable.data._sync.client import Table from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient +if not CrossSync._Sync_Impl.is_async: + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto class _MutateRowsOperation: diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index 62755e6e0..27d3417d7 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -14,7 +14,7 @@ # # This file is automatically generated by CrossSync. Do not edit manually. 
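# A minimal, self-contained sketch of the CrossSync.Awaitable aliasing this commit adds;
# the _AsyncSurface/_SyncSurface/fetch_* names below are illustrative only, not library code.
from __future__ import annotations
from typing import Awaitable, AsyncIterable, Iterable, TypeVar, Union

T = TypeVar("T")

class _AsyncSurface:
    # async annotations keep typing.Awaitable, as on CrossSync
    Awaitable = Awaitable

class _SyncSurface:
    # mirrors the `Awaitable: TypeAlias = Union[T]` placeholder on _Sync_Impl; with
    # postponed annotation evaluation the sync surface only needs the name to exist
    Awaitable = Union[T]

async def fetch_async(stream: _AsyncSurface.Awaitable[AsyncIterable[int]]) -> None:
    ...

def fetch_sync(stream: _SyncSurface.Awaitable[Iterable[int]]) -> None:
    ...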
from __future__ import annotations -from typing import TYPE_CHECKING, Awaitable, Sequence +from typing import TYPE_CHECKING, Sequence from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB from google.cloud.bigtable_v2.types import RowSet as RowSetPB @@ -34,6 +34,8 @@ else: from google.cloud.bigtable.data._sync.client import Table from typing import Iterable +if not CrossSync._Sync_Impl.is_async: + from google.cloud.bigtable.data._async._read_rows import _ResetRow class _ReadRowsOperation: @@ -136,7 +138,7 @@ def _read_rows_attempt(self) -> Iterable[Row]: return self.merge_rows(chunked_stream) def chunk_stream( - self, stream: Iterable[ReadRowsResponsePB] + self, stream: CrossSync._Sync_Impl.Awaitable[Iterable[ReadRowsResponsePB]] ) -> Iterable[ReadRowsResponsePB.CellChunk]: """process chunks out of raw read_rows stream diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 70781de2d..97d5b2e7b 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -14,7 +14,7 @@ # from __future__ import annotations -from typing import TypeVar, Any, Awaitable, Callable, Coroutine, Sequence, TYPE_CHECKING +from typing import TypeVar, Any, Awaitable, Callable, Coroutine, Sequence, Union, TYPE_CHECKING import asyncio import sys @@ -43,6 +43,7 @@ class CrossSync: Task: TypeAlias = asyncio.Task Event: TypeAlias = asyncio.Event Semaphore: TypeAlias = asyncio.Semaphore + Awaitable: TypeAlias = Awaitable generated_replacements: dict[type, str] = {} @@ -201,6 +202,7 @@ class _Sync_Impl: Task: TypeAlias = concurrent.futures.Future Event: TypeAlias = threading.Event Semaphore: TypeAlias = threading.Semaphore + Awaitable: TypeAlias = Union[T] generated_replacements: dict[type, str] = {} From 40678e05eddb35e1055c0c11e3d2ec69bfe4bb6f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 12:22:32 -0600 Subject: [PATCH 121/360] more targeted replacements --- .../bigtable/data/_async/_mutate_rows.py | 8 ++--- .../cloud/bigtable/data/_async/_read_rows.py | 27 +++++++---------- google/cloud/bigtable/data/_async/client.py | 29 +++++++++---------- .../bigtable/data/_async/mutations_batcher.py | 16 +++++----- .../cloud/bigtable/data/_sync/_read_rows.py | 20 +++++++------ .../cloud/bigtable/data/_sync/cross_sync.py | 12 +++++++- .../bigtable/data/_sync/mutations_batcher.py | 5 +++- 7 files changed, 60 insertions(+), 57 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index d3f12af3b..47715cb99 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -63,10 +63,6 @@ class _EntryWithProto: @CrossSync.sync_output( "google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", - replace_symbols={ - "BigtableAsyncClient": "BigtableClient", - "TableAsync": "Table", - }, ) class _MutateRowsOperationAsync: """ @@ -87,6 +83,10 @@ class _MutateRowsOperationAsync: If not specified, the request will run until operation_timeout is reached. 
""" + @CrossSync.convert(replace_symbols={ + "BigtableAsyncClient": "BigtableClient", + "TableAsync": "Table", + }) def __init__( self, gapic_client: "BigtableAsyncClient", diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 09cc59e03..f786de98d 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -17,7 +17,6 @@ from typing import ( TYPE_CHECKING, - AsyncIterable, Awaitable, Sequence, ) @@ -44,7 +43,6 @@ from google.cloud.bigtable.data._async.client import TableAsync else: from google.cloud.bigtable.data._sync.client import Table # noqa: F401 - from typing import Iterable # noqa: F401 class _ResetRow(Exception): @@ -56,13 +54,6 @@ def __init__(self, chunk): @CrossSync.sync_output( "google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", - replace_symbols={ - "AsyncIterable": "Iterable", - "StopAsyncIteration": "StopIteration", - "TableAsync": "Table", - "__aiter__": "__iter__", - "__anext__": "__next__", - }, ) class _ReadRowsOperationAsync: """ @@ -95,6 +86,7 @@ class _ReadRowsOperationAsync: "_remaining_count", ) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def __init__( self, query: ReadRowsQuery, @@ -124,7 +116,7 @@ def __init__( self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None - def start_operation(self) -> AsyncIterable[Row]: + def start_operation(self) -> CrossSync.Iterable[Row]: """ Start the read_rows operation, retrying on retryable errors. @@ -139,7 +131,7 @@ def start_operation(self) -> AsyncIterable[Row]: exception_factory=_helpers._retry_exception_factory, ) - def _read_rows_attempt(self) -> AsyncIterable[Row]: + def _read_rows_attempt(self) -> CrossSync.Iterable[Row]: """ Attempt a single read_rows rpc call. 
This function is intended to be wrapped by retry logic, @@ -176,8 +168,8 @@ def _read_rows_attempt(self) -> AsyncIterable[Row]: return self.merge_rows(chunked_stream) async def chunk_stream( - self, stream: CrossSync.Awaitable[AsyncIterable[ReadRowsResponsePB]] - ) -> AsyncIterable[ReadRowsResponsePB.CellChunk]: + self, stream: CrossSync.Awaitable[CrossSync.Iterable[ReadRowsResponsePB]] + ) -> CrossSync.Iterable[ReadRowsResponsePB.CellChunk]: """ process chunks out of raw read_rows stream @@ -227,9 +219,10 @@ async def chunk_stream( current_key = None @staticmethod + @CrossSync.convert(replace_symbols={"__aiter__": "__iter__", "__anext__": "__next__"}) async def merge_rows( - chunks: AsyncIterable[ReadRowsResponsePB.CellChunk] | None, - ) -> AsyncIterable[Row]: + chunks: CrossSync.Iterable[ReadRowsResponsePB.CellChunk] | None, + ) -> CrossSync.Iterable[Row]: """ Merge chunks into rows @@ -245,7 +238,7 @@ async def merge_rows( while True: try: c = await it.__anext__() - except StopAsyncIteration: + except CrossSync.StopIteration: # stream complete return row_key = c.row_key @@ -338,7 +331,7 @@ async def merge_rows( ): raise InvalidChunk("reset row with data") continue - except StopAsyncIteration: + except CrossSync.StopIteration: raise InvalidChunk("premature end of stream") @staticmethod diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 1b0dc0239..9dd001543 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -112,14 +112,13 @@ @CrossSync.sync_output( "google.cloud.bigtable.data._sync.client.BigtableDataClient", - replace_symbols={ - "TableAsync": "Table", - "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", - "BigtableAsyncClient": "BigtableClient", - "AsyncPooledChannel": "PooledChannel", - }, ) class BigtableDataClientAsync(ClientWithProject): + @CrossSync.convert(replace_symbols={ + "BigtableAsyncClient": "BigtableClient", + "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", + "AsyncPooledChannel": "PooledChannel", + }) def __init__( self, *, @@ -375,6 +374,7 @@ async def _manage_channel( next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = next_refresh - (time.monotonic() - start_timestamp) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: """ Registers an instance with the client, and warms the channel pool @@ -405,6 +405,7 @@ async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: # refresh tasks aren't active. start them as background tasks self._start_background_channel_refresh() + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def _remove_instance_registration( self, instance_id: str, owner: TableAsync ) -> bool: @@ -435,6 +436,7 @@ async def _remove_instance_registration( except KeyError: return False + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: """ Returns a table instance for making data API requests. 
All arguments are passed @@ -487,16 +489,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) -@CrossSync.sync_output( - "google.cloud.bigtable.data._sync.client.Table", - replace_symbols={ - "AsyncIterable": "Iterable", - "MutationsBatcherAsync": "MutationsBatcher", - "BigtableDataClientAsync": "BigtableDataClient", - "_ReadRowsOperationAsync": "_ReadRowsOperation", - "_MutateRowsOperationAsync": "_MutateRowsOperation", - }, -) +@CrossSync.sync_output("google.cloud.bigtable.data._sync.client.Table") class TableAsync: """ Main Data API surface @@ -505,6 +498,7 @@ class TableAsync: each call """ + @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}) def __init__( self, client: BigtableDataClientAsync, @@ -625,6 +619,7 @@ def __init__( f"{self.__class__.__name__} must be created within an async event loop context." ) from e + @CrossSync.convert(replace_symbols={"AsyncIterable": "Iterable", "_ReadRowsOperationAsync": "_ReadRowsOperation"}) async def read_rows_stream( self, query: ReadRowsQuery, @@ -990,6 +985,7 @@ async def execute_rpc(): exception_factory=_helpers._retry_exception_factory, ) + @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) def mutations_batcher( self, *, @@ -1117,6 +1113,7 @@ async def mutate_row( exception_factory=_helpers._retry_exception_factory, ) + @CrossSync.convert(replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"}) async def bulk_mutate_rows( self, mutation_entries: list[RowMutationEntry], diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 1f5b26bbb..d723c4118 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -27,7 +27,6 @@ from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import _MB_SIZE -from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync from google.cloud.bigtable.data.mutations import ( _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, ) @@ -35,6 +34,11 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if CrossSync.is_async: + from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync +else: + from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation + if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -43,9 +47,6 @@ from google.cloud.bigtable.data._async.client import TableAsync else: from google.cloud.bigtable.data._sync.client import Table # noqa: F401 - from google.cloud.bigtable.data._sync._mutate_rows import ( # noqa: F401 - _MutateRowsOperation, - ) @CrossSync.sync_output( @@ -179,11 +180,6 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] @CrossSync.sync_output( "google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", - replace_symbols={ - "TableAsync": "Table", - "_FlowControlAsync": "_FlowControl", - "_MutateRowsOperationAsync": "_MutateRowsOperation", - }, mypy_ignore=["unreachable"], ) class MutationsBatcherAsync: @@ -217,6 +213,7 @@ class MutationsBatcherAsync: Defaults to the Table's default_mutate_rows_retryable_errors. 
""" + @CrossSync.convert(replace_symbols={"TableAsync": "Table", "_FlowControlAsync": "_FlowControl"}) def __init__( self, table: TableAsync, @@ -361,6 +358,7 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): self._entries_processed_since_last_raise += len(new_entries) self._add_exceptions(found_exceptions) + @CrossSync.convert(replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"}) async def _execute_mutate_rows( self, batch: list[RowMutationEntry] ) -> list[FailedMutationEntryError]: diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index 27d3417d7..7a6d1300c 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -33,7 +33,6 @@ pass else: from google.cloud.bigtable.data._sync.client import Table - from typing import Iterable if not CrossSync._Sync_Impl.is_async: from google.cloud.bigtable.data._async._read_rows import _ResetRow @@ -95,7 +94,7 @@ def __init__( self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None - def start_operation(self) -> Iterable[Row]: + def start_operation(self) -> CrossSync._Sync_Impl.Iterable[Row]: """Start the read_rows operation, retrying on retryable errors. Yields: @@ -108,7 +107,7 @@ def start_operation(self) -> Iterable[Row]: exception_factory=_helpers._retry_exception_factory, ) - def _read_rows_attempt(self) -> Iterable[Row]: + def _read_rows_attempt(self) -> CrossSync._Sync_Impl.Iterable[Row]: """Attempt a single read_rows rpc call. This function is intended to be wrapped by retry logic, which will call this function until it succeeds or @@ -138,8 +137,11 @@ def _read_rows_attempt(self) -> Iterable[Row]: return self.merge_rows(chunked_stream) def chunk_stream( - self, stream: CrossSync._Sync_Impl.Awaitable[Iterable[ReadRowsResponsePB]] - ) -> Iterable[ReadRowsResponsePB.CellChunk]: + self, + stream: CrossSync._Sync_Impl.Awaitable[ + CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB] + ], + ) -> CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB.CellChunk]: """process chunks out of raw read_rows stream Args: @@ -179,8 +181,8 @@ def chunk_stream( @staticmethod def merge_rows( - chunks: Iterable[ReadRowsResponsePB.CellChunk] | None, - ) -> Iterable[Row]: + chunks: CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB.CellChunk] | None, + ) -> CrossSync._Sync_Impl.Iterable[Row]: """Merge chunks into rows Args: @@ -193,7 +195,7 @@ def merge_rows( while True: try: c = it.__next__() - except StopIteration: + except CrossSync._Sync_Impl.StopIteration: return row_key = c.row_key if not row_key: @@ -273,7 +275,7 @@ def merge_rows( ): raise InvalidChunk("reset row with data") continue - except StopIteration: + except CrossSync._Sync_Impl.StopIteration: raise InvalidChunk("premature end of stream") @staticmethod diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 97d5b2e7b..3fd5221a6 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -14,7 +14,7 @@ # from __future__ import annotations -from typing import TypeVar, Any, Awaitable, Callable, Coroutine, Sequence, Union, TYPE_CHECKING +from typing import TypeVar, Any, Awaitable, Callable, Coroutine, Sequence, Union, AsyncIterable, AsyncIterator, AsyncGenerator, Iterable, Iterator, Generator, TYPE_CHECKING import asyncio import sys @@ -43,7 +43,12 @@ class CrossSync: Task: TypeAlias = 
asyncio.Task Event: TypeAlias = asyncio.Event Semaphore: TypeAlias = asyncio.Semaphore + StopIteration: TypeAlias = StopAsyncIteration + # type annotations Awaitable: TypeAlias = Awaitable + Iterable: TypeAlias = AsyncIterable + Iterator: TypeAlias = AsyncIterator + Generator: TypeAlias = AsyncGenerator generated_replacements: dict[type, str] = {} @@ -202,7 +207,12 @@ class _Sync_Impl: Task: TypeAlias = concurrent.futures.Future Event: TypeAlias = threading.Event Semaphore: TypeAlias = threading.Semaphore + StopIteration: TypeAlias = StopAsyncIteration + # type annotations Awaitable: TypeAlias = Union[T] + Iterable: TypeAlias = Iterable + Iterator: TypeAlias = Iterator + Generator: TypeAlias = Generator generated_replacements: dict[type, str] = {} diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 524dab78d..e2bd5f370 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -32,6 +32,10 @@ from google.cloud.bigtable.data.mutations import Mutation from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if CrossSync._Sync_Impl.is_async: + pass +else: + from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -39,7 +43,6 @@ pass else: from google.cloud.bigtable.data._sync.client import Table - from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation class MutationsBatcher: From 79f7b747a8b8568cc9a9c525ae6048d4f0f9985d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 13:19:37 -0600 Subject: [PATCH 122/360] got test conversion working --- .../cloud/bigtable/data/_sync/cross_sync.py | 16 +++-- .../data/_async/test_read_rows_acceptance.py | 27 ++++++- tests/unit/data/_sync/test__mutate_rows.py | 8 +-- tests/unit/data/_sync/test__read_rows.py | 22 ++---- tests/unit/data/_sync/test_client.py | 65 +++++++---------- .../unit/data/_sync/test_mutations_batcher.py | 70 +++++++------------ 6 files changed, 97 insertions(+), 111 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 3fd5221a6..a3dcb2e54 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -301,7 +301,7 @@ def yield_to_event_loop() -> None: class CrossSyncArtifact: file_path: str imports: list[ast.Import | ast.ImportFrom] = field(default_factory=list) - converted_classes: dict[type, ast.ClassDef] = field(default_factory=dict) + converted_classes: dict[str, ast.ClassDef] = field(default_factory=dict) mypy_ignores: list[str] = field(default_factory=list) _instances: ClassVar[dict[str, CrossSyncArtifact]] = {} @@ -344,7 +344,7 @@ def get_for_path(cls, path: str) -> CrossSyncArtifact: return cls._instances[path] def add_class(self, cls): - if cls in self.converted_classes: + if cls.cross_sync_class_name in self.converted_classes: return crosssync_converter = transformers.SymbolReplacer({"CrossSync": "CrossSync._Sync_Impl"}) # convert class @@ -360,7 +360,7 @@ def add_class(self, cls): converted = transformers.SymbolReplacer(cls.cross_sync_replace_symbols).visit(converted) converted = crosssync_converter.visit(converted) converted = transformers.HandleCrossSyncDecorators().visit(converted) - self.converted_classes[cls] = converted + self.converted_classes[cls.cross_sync_class_name] = converted # add 
imports for added class if required if cls.include_file_imports and not self.imports: with open(inspect.getfile(cls)) as f: @@ -382,7 +382,13 @@ def add_class(self, cls): import autoflake # find all cross_sync decorated classes search_root = sys.argv[1] - found_files = [path.replace("/", ".")[:-3] for path in glob.glob(search_root + "/**/*.py", recursive=True)] + # change directory to the root of the project + orig_dir = os.getcwd() + os.chdir(search_root) + # find all python files rooted here + found_files = [path.replace("/", ".")[:-3] for path in glob.glob("**/*.py", recursive=True)] + # add to path + sys.path.append(".") found_classes = list(itertools.chain.from_iterable([inspect.getmembers(importlib.import_module(path), inspect.isclass) for path in found_files])) file_obj_set = set() for name, cls in found_classes: @@ -390,6 +396,8 @@ def add_class(self, cls): file_obj = CrossSyncArtifact.get_for_path(cls.cross_sync_file_path) file_obj.add_class(cls) file_obj_set.add(file_obj) + # write out the files + os.chdir(orig_dir) for file_obj in file_obj_set: with open(file_obj.file_path, "w") as f: f.write(file_obj.render()) diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 7434e20af..1779ae2ac 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -17,6 +17,7 @@ import warnings import pytest import mock +import proto from itertools import zip_longest @@ -25,11 +26,33 @@ from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.row import Row -from tests.unit.v2_client.test_row_merger import ReadRowsTest, TestFile - from google.cloud.bigtable.data._sync.cross_sync import CrossSync +# TODO: autogenerate protos from +# https://github.com/googleapis/conformance-tests/blob/main/bigtable/v2/proto/google/cloud/conformance/bigtable/v2/tests.proto +class ReadRowsTest(proto.Message): + class Result(proto.Message): + row_key = proto.Field(proto.STRING, number=1) + family_name = proto.Field(proto.STRING, number=2) + qualifier = proto.Field(proto.STRING, number=3) + timestamp_micros = proto.Field(proto.INT64, number=4) + value = proto.Field(proto.STRING, number=5) + label = proto.Field(proto.STRING, number=6) + error = proto.Field(proto.BOOL, number=7) + + description = proto.Field(proto.STRING, number=1) + chunks = proto.RepeatedField( + proto.MESSAGE, number=2, message=ReadRowsResponse.CellChunk + ) + results = proto.RepeatedField(proto.MESSAGE, number=3, message=Result) + + +class TestFile(proto.Message): + __test__ = False + read_rows_tests = proto.RepeatedField(proto.MESSAGE, number=1, message=ReadRowsTest) + + @CrossSync.sync_output( "tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", ) diff --git a/tests/unit/data/_sync/test__mutate_rows.py b/tests/unit/data/_sync/test__mutate_rows.py index f4fc2d279..b4974b690 100644 --- a/tests/unit/data/_sync/test__mutate_rows.py +++ b/tests/unit/data/_sync/test__mutate_rows.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# This file is automatically generated by sync_surface_generator.py. Do not edit. - - +# This file is automatically generated by CrossSync. Do not edit manually. 
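# A self-contained sketch of the proto-plus pattern used for the hand-declared
# ReadRowsTest/TestFile conformance messages in this commit; ExampleResult/ExampleCase
# and the field values are made-up illustrations, not part of the test suite.
import proto

class ExampleResult(proto.Message):
    row_key = proto.Field(proto.STRING, number=1)
    value = proto.Field(proto.STRING, number=2)

class ExampleCase(proto.Message):
    description = proto.Field(proto.STRING, number=1)
    results = proto.RepeatedField(proto.MESSAGE, number=2, message=ExampleResult)

case = ExampleCase(description="one row", results=[ExampleResult(row_key="rk", value="v")])
assert case.results[0].row_key == "rk"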
import pytest try: @@ -24,11 +22,11 @@ import mock from mock import AsyncMock +from google.cloud.bigtable_v2.types import MutateRowsResponse +from google.rpc import status_pb2 from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import Forbidden from google.cloud.bigtable.data._sync.cross_sync import CrossSync -from google.cloud.bigtable_v2.types import MutateRowsResponse -from google.rpc import status_pb2 class TestMutateRowsOperation: diff --git a/tests/unit/data/_sync/test__read_rows.py b/tests/unit/data/_sync/test__read_rows.py index 78b61cebc..bdf0a42ce 100644 --- a/tests/unit/data/_sync/test__read_rows.py +++ b/tests/unit/data/_sync/test__read_rows.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# This file is automatically generated by sync_surface_generator.py. Do not edit. - - +# This file is automatically generated by CrossSync. Do not edit manually. import pytest try: @@ -230,13 +228,11 @@ def test_revise_to_empty_rowset(self): ], ) def test_revise_limit(self, start_limit, emit_num, expected_limit): - """ - revise_limit should revise the request's limit field + """revise_limit should revise the request's limit field - if limit is 0 (unlimited), it should never be revised - if start_limit-emit_num == 0, the request should end early - if the number emitted exceeds the new limit, an exception should - should be raised (tested in test_revise_limit_over_limit) - """ + should be raised (tested in test_revise_limit_over_limit)""" from google.cloud.bigtable.data import ReadRowsQuery from google.cloud.bigtable_v2.types import ReadRowsResponse @@ -269,10 +265,8 @@ def mock_stream(): @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) def test_revise_limit_over_limit(self, start_limit, emit_num): - """ - Should raise runtime error if we get in state where emit_num > start_num - (unless start_num == 0, which represents unlimited) - """ + """Should raise runtime error if we get in state where emit_num > start_num + (unless start_num == 0, which represents unlimited)""" from google.cloud.bigtable.data import ReadRowsQuery from google.cloud.bigtable_v2.types import ReadRowsResponse from google.cloud.bigtable.data.exceptions import InvalidChunk @@ -306,10 +300,8 @@ def mock_stream(): assert "emit count exceeds row limit" in str(e.value) def test_aclose(self): - """ - should be able to close a stream safely with aclose. - Closed generators should raise StopAsyncIteration on next yield - """ + """should be able to close a stream safely with aclose. + Closed generators should raise StopAsyncIteration on next yield""" def mock_stream(): while True: diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 4140cbe9c..de4e069d6 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# This file is automatically generated by sync_surface_generator.py. Do not edit. - - +# This file is automatically generated by CrossSync. Do not edit manually. 
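# A short sketch of the cleanup pass applied to rendered output before regenerated
# modules like the one below are written: autoflake drops unused imports, then black
# normalizes formatting. The call pattern follows the generator's render step; the
# sample source string is made up.
import autoflake
import black

raw = "import os\nimport sys\n\nprint(sys.version)\n"
cleaned = autoflake.fix_code(raw, remove_all_unused_imports=True)  # "import os" is removed
print(black.format_str(cleaned, mode=black.FileMode()))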
from __future__ import annotations -import asyncio import grpc -import pytest +import asyncio import re try: @@ -28,17 +25,19 @@ import mock from mock import AsyncMock -from google.api_core import exceptions as core_exceptions -from google.auth.credentials import AnonymousCredentials -from google.cloud.bigtable.data import TABLE_DEFAULT +import sys +import pytest from google.cloud.bigtable.data import mutations -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.auth.credentials import AnonymousCredentials +from google.cloud.bigtable_v2.types import ReadRowsResponse +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.api_core import exceptions as core_exceptions from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete -from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule +from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule -from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery -from google.cloud.bigtable_v2.types import ReadRowsResponse +from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule +from google.cloud.bigtable.data._sync.cross_sync import CrossSync if CrossSync._Sync_Impl.is_async: pass @@ -706,11 +705,9 @@ def test__remove_instance_registration(self): client.close() def test__multiple_table_registration(self): - """ - registering with multiple tables with the same key should + """registering with multiple tables with the same key should add multiple owners to instance_owners, but only keep one copy - of shared key in active_instances - """ + of shared key in active_instances""" from google.cloud.bigtable.data._helpers import _WarmedInstanceKey with self._make_client(project="project-id") as client: @@ -756,10 +753,8 @@ def test__multiple_table_registration(self): assert len(client._instance_owners[instance_1_key]) == 0 def test__multiple_instance_registration(self): - """ - registering with multiple instance keys should update the key - in instance_owners and active_instances - """ + """registering with multiple instance keys should update the key + in instance_owners and active_instances""" from google.cloud.bigtable.data._helpers import _WarmedInstanceKey with self._make_client(project="project-id") as client: @@ -1881,13 +1876,11 @@ def test_read_rows_timeout(self, operation_timeout): [(0.05, 0.08, 2), (0.05, 0.14, 3), (0.05, 0.24, 5)], ) def test_read_rows_attempt_timeout(self, per_request_t, operation_t, expected_num): - """ - Ensures that the attempt_timeout is respected and that the number of + """Ensures that the attempt_timeout is respected and that the number of requests is as expected. operation_timeout does not cancel the request, so we expect the number of - requests to be the ceiling of operation_timeout / attempt_timeout. - """ + requests to be the ceiling of operation_timeout / attempt_timeout.""" from google.cloud.bigtable.data.exceptions import RetryExceptionGroup expected_last_timeout = operation_t - (expected_num - 1) * per_request_t @@ -2260,11 +2253,9 @@ def mock_call(*args, **kwargs): assert call_time < 0.2 def test_read_rows_sharded_concurrency_limit(self): - """ - Only 10 queries should be processed concurrently. Others should be queued + """Only 10 queries should be processed concurrently. 
Others should be queued - Should start a new query as soon as previous finishes - """ + Should start a new query as soon as previous finishes""" from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT assert _CONCURRENCY_LIMIT == 10 @@ -2299,10 +2290,8 @@ def mock_call(*args, **kwargs): assert rpc_start_list[idx] - i * increment_time < eps def test_read_rows_sharded_expirary(self): - """ - If the operation times out before all shards complete, should raise - a ShardedReadRowsExceptionGroup - """ + """If the operation times out before all shards complete, should raise + a ShardedReadRowsExceptionGroup""" from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.api_core.exceptions import DeadlineExceeded @@ -2335,11 +2324,9 @@ def mock_call(*args, **kwargs): assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT def test_read_rows_sharded_negative_batch_timeout(self): - """ - try to run with batch that starts after operation timeout + """try to run with batch that starts after operation timeout - They should raise DeadlineExceeded errors - """ + They should raise DeadlineExceeded errors""" from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.api_core.exceptions import DeadlineExceeded @@ -2653,10 +2640,8 @@ def test_customizable_retryable_errors( is_stream, extra_retryables, ): - """ - Test that retryable functions support user-configurable arguments, and that the configured retryables are passed - down to the gapic layer. - """ + """Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer.""" retry_fn = "retry_target" if is_stream: retry_fn += "_stream" diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py index afdf9f905..4064fd3e8 100644 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# This file is automatically generated by sync_surface_generator.py. Do not edit. - - -import asyncio +# This file is automatically generated by CrossSync. Do not edit manually. import pytest +import asyncio import time try: @@ -25,12 +23,11 @@ except ImportError: import mock from mock import AsyncMock - -from google.cloud.bigtable.data import TABLE_DEFAULT -from google.cloud.bigtable.data import TableAsync -from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete import google.api_core.exceptions as core_exceptions import google.api_core.retry +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data import TABLE_DEFAULT +from google.cloud.bigtable.data import TableAsync class TestMutationsBatcher: @@ -199,11 +196,9 @@ def test_ctor_invalid_values(self): assert "attempt_timeout must be greater than 0" in str(e.value) def test_default_argument_consistency(self): - """ - We supply default arguments in MutationsBatcherAsync.__init__, and in + """We supply default arguments in MutationsBatcherAsync.__init__, and in table.mutations_batcher. 
Make sure any changes to defaults are applied to - both places - """ + both places""" import inspect get_batcher_signature = dict( @@ -304,10 +299,8 @@ def test_append_closed(self): instance.append(mock.Mock()) def test_append_wrong_mutation(self): - """ - Mutation objects should raise an exception. - Only support RowMutationEntry - """ + """Mutation objects should raise an exception. + Only support RowMutationEntry""" from google.cloud.bigtable.data.mutations import DeleteAllFromRow with self._make_one() as instance: @@ -338,10 +331,8 @@ def test_append_outside_flow_limits(self): instance._staged_entries = [] def test_append_flush_runs_after_limit_hit(self): - """ - If the user appends a bunch of entries above the flush limits back-to-back, - it should still flush in a single task - """ + """If the user appends a bunch of entries above the flush limits back-to-back, + it should still flush in a single task""" with mock.patch.object( self._get_target_class(), "_execute_mutate_rows" ) as op_mock: @@ -480,13 +471,11 @@ def test_schedule_flush_with_mutations(self): flush_mock.reset_mock() def test__flush_internal(self): - """ - _flush_internal should: - - await previous flush call - - delegate batching to _flow_control - - call _execute_mutate_rows on each batch - - update self.exceptions and self._entries_processed_since_last_raise - """ + """_flush_internal should: + - await previous flush call + - delegate batching to _flow_control + - call _execute_mutate_rows on each batch + - update self.exceptions and self._entries_processed_since_last_raise""" num_entries = 10 with self._make_one() as instance: with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: @@ -507,10 +496,8 @@ def gen(x): instance._newest_exceptions.clear() def test_flush_clears_job_list(self): - """ - a job should be added to _flush_jobs when _schedule_flush is called, - and removed when it completes - """ + """a job should be added to _flush_jobs when _schedule_flush is called, + and removed when it completes""" with self._make_one() as instance: with mock.patch.object( instance, "_flush_internal", AsyncMock() @@ -751,10 +738,8 @@ def test_atexit_registration(self): assert register_mock.call_count == 1 def test_timeout_args_passed(self): - """ - batch_operation_timeout and batch_attempt_timeout should be used - in api calls - """ + """batch_operation_timeout and batch_attempt_timeout should be used + in api calls""" if self.is_async(): mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" else: @@ -798,8 +783,7 @@ def test_timeout_args_passed(self): ], ) def test__add_exceptions(self, limit, in_e, start_e, end_e): - """ - Test that the _add_exceptions function properly updates the + """Test that the _add_exceptions function properly updates the _oldest_exceptions and _newest_exceptions lists Args: - limit: the _exception_list_limit representing the max size of either list @@ -855,10 +839,8 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): ], ) def test_customizable_retryable_errors(self, input_retryables, expected_retryables): - """ - Test that retryable functions support user-configurable arguments, and that the configured retryables are passed - down to the gapic layer. 
- """ + """Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer.""" retryn_fn = ( "google.cloud.bigtable.data._sync.cross_sync.CrossSync.retry_target" if "Async" in self._get_target_class().__name__ @@ -1087,10 +1069,8 @@ def test_add_to_flow(self, mutations, count_cap, size_cap, expected_results): def test_add_to_flow_max_mutation_limits( self, mutations, max_limit, expected_results ): - """ - Test flow control running up against the max API limit - Should submit request early, even if the flow control has room for more - """ + """Test flow control running up against the max API limit + Should submit request early, even if the flow control has room for more""" async_patch = mock.patch( "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", max_limit, From 2652158855da3d49d94a9c68b8f08c6b47de290c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 13:29:12 -0600 Subject: [PATCH 123/360] keep try import blocks --- google/cloud/bigtable/data/_sync/cross_sync.py | 4 ++-- tests/unit/data/_sync/test__mutate_rows.py | 11 +++++------ tests/unit/data/_sync/test_client.py | 14 ++++++-------- tests/unit/data/_sync/test_mutations_batcher.py | 10 +++++----- tests/unit/data/_sync/test_read_rows_acceptance.py | 14 +++++--------- 5 files changed, 23 insertions(+), 30 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index a3dcb2e54..ebe1e1f0a 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -300,7 +300,7 @@ def yield_to_event_loop() -> None: @dataclass class CrossSyncArtifact: file_path: str - imports: list[ast.Import | ast.ImportFrom] = field(default_factory=list) + imports: list[ast.Import | ast.ImportFrom | ast.If | ast.Try] = field(default_factory=list) converted_classes: dict[str, ast.ClassDef] = field(default_factory=dict) mypy_ignores: list[str] = field(default_factory=list) _instances: ClassVar[dict[str, CrossSyncArtifact]] = {} @@ -366,7 +366,7 @@ def add_class(self, cls): with open(inspect.getfile(cls)) as f: full_ast = ast.parse(f.read()) for node in full_ast.body: - if isinstance(node, (ast.Import, ast.ImportFrom, ast.If)): + if isinstance(node, (ast.Import, ast.ImportFrom, ast.If, ast.Try)): self.imports.append(crosssync_converter.visit(node)) # add mypy ignore if required if cls.cross_sync_mypy_ignore: diff --git a/tests/unit/data/_sync/test__mutate_rows.py b/tests/unit/data/_sync/test__mutate_rows.py index b4974b690..bcdd1103c 100644 --- a/tests/unit/data/_sync/test__mutate_rows.py +++ b/tests/unit/data/_sync/test__mutate_rows.py @@ -14,6 +14,11 @@ # # This file is automatically generated by CrossSync. Do not edit manually. 
import pytest +from google.cloud.bigtable_v2.types import MutateRowsResponse +from google.rpc import status_pb2 +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import Forbidden +from google.cloud.bigtable.data._sync.cross_sync import CrossSync try: from unittest import mock @@ -22,12 +27,6 @@ import mock from mock import AsyncMock -from google.cloud.bigtable_v2.types import MutateRowsResponse -from google.rpc import status_pb2 -from google.api_core.exceptions import DeadlineExceeded -from google.api_core.exceptions import Forbidden -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - class TestMutateRowsOperation: def _target_class(self): diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index de4e069d6..9df88dad3 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -17,14 +17,6 @@ import grpc import asyncio import re - -try: - from unittest import mock - from unittest.mock import AsyncMock -except ImportError: - import mock - from mock import AsyncMock - import sys import pytest from google.cloud.bigtable.data import mutations @@ -39,6 +31,12 @@ from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule from google.cloud.bigtable.data._sync.cross_sync import CrossSync +try: + from unittest import mock + from unittest.mock import AsyncMock +except ImportError: + import mock + from mock import AsyncMock if CrossSync._Sync_Impl.is_async: pass else: diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py index 4064fd3e8..9affe0d7a 100644 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -16,6 +16,11 @@ import pytest import asyncio import time +import google.api_core.exceptions as core_exceptions +import google.api_core.retry +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data import TABLE_DEFAULT +from google.cloud.bigtable.data import TableAsync try: from unittest import mock @@ -23,11 +28,6 @@ except ImportError: import mock from mock import AsyncMock -import google.api_core.exceptions as core_exceptions -import google.api_core.retry -from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete -from google.cloud.bigtable.data import TABLE_DEFAULT -from google.cloud.bigtable.data import TableAsync class TestMutationsBatcher: diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync/test_read_rows_acceptance.py index 1f423232c..25e59f53f 100644 --- a/tests/unit/data/_sync/test_read_rows_acceptance.py +++ b/tests/unit/data/_sync/test_read_rows_acceptance.py @@ -12,20 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# This file is automatically generated by sync_surface_generator.py. Do not edit. - - +# This file is automatically generated by CrossSync. Do not edit manually. 
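# A toy stand-in (not the library's SymbolReplacer) for the rename pass that leaves
# generated modules like the one below referring to CrossSync._Sync_Impl; the input
# snippet is illustrative.
import ast

class _ToyCrossSyncRenamer(ast.NodeTransformer):
    def visit_Name(self, node):
        if node.id == "CrossSync":
            new = ast.Attribute(
                value=ast.Name(id="CrossSync", ctx=ast.Load()),
                attr="_Sync_Impl",
                ctx=node.ctx,
            )
            return ast.copy_location(new, node)
        return node

tree = _ToyCrossSyncRenamer().visit(ast.parse("if not CrossSync.is_async:\n    pass\n"))
print(ast.unparse(ast.fix_missing_locations(tree)))  # if not CrossSync._Sync_Impl.is_async: ...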
from __future__ import annotations -from itertools import zip_longest -from tests.unit.v2_client.test_row_merger import ReadRowsTest, TestFile -import mock import os -import pytest import warnings - +import pytest +import mock +from itertools import zip_longest +from google.cloud.bigtable_v2 import ReadRowsResponse from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.row import Row -from google.cloud.bigtable_v2 import ReadRowsResponse class TestReadRowsAcceptance: From 5b317792741803abe495a7f70c15ef84f55e13b9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 16:49:01 -0600 Subject: [PATCH 124/360] refactored into node transformer --- .../cloud/bigtable/data/_sync/cross_sync.py | 124 +--------- .../bigtable/data/_sync/mutations_batcher.py | 218 +++++++++--------- .../cloud/bigtable/data/_sync/transformers.py | 142 ++++++++++++ 3 files changed, 262 insertions(+), 222 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index ebe1e1f0a..716160631 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -67,24 +67,13 @@ def drop_method(func): def sync_output( cls, sync_path: str, + *, replace_symbols: dict["str", "str" | None ] | None = None, mypy_ignore: list[str] | None = None, - include_file_imports: bool = True, + include_file_imports: bool = False, ): - replace_symbols = replace_symbols or {} - mypy_ignore = mypy_ignore or [] - include_file_imports = include_file_imports - # return the async class unchanged def decorator(async_cls): - cls.generated_replacements[async_cls] = sync_path - async_cls.cross_sync_enabled = True - async_cls.cross_sync_import_path = sync_path - async_cls.cross_sync_class_name = sync_path.rsplit(".", 1)[-1] - async_cls.cross_sync_file_path = "/".join(sync_path.split(".")[:-1]) + ".py" - async_cls.cross_sync_replace_symbols = replace_symbols - async_cls.cross_sync_mypy_ignore = mypy_ignore - async_cls.include_file_imports = include_file_imports return async_cls return decorator @@ -291,87 +280,8 @@ def create_task( def yield_to_event_loop() -> None: pass -from dataclasses import dataclass, field -from typing import ClassVar -import ast - from google.cloud.bigtable.data._sync import transformers -@dataclass -class CrossSyncArtifact: - file_path: str - imports: list[ast.Import | ast.ImportFrom | ast.If | ast.Try] = field(default_factory=list) - converted_classes: dict[str, ast.ClassDef] = field(default_factory=dict) - mypy_ignores: list[str] = field(default_factory=list) - _instances: ClassVar[dict[str, CrossSyncArtifact]] = {} - - def __hash__(self): - return hash(self.file_path) - - def render(self, with_black=True) -> str: - full_str = ( - "# Copyright 2024 Google LLC\n" - "#\n" - '# Licensed under the Apache License, Version 2.0 (the "License");\n' - '# you may not use this file except in compliance with the License.\n' - '# You may obtain a copy of the License at\n' - '#\n' - '# http://www.apache.org/licenses/LICENSE-2.0\n' - '#\n' - '# Unless required by applicable law or agreed to in writing, software\n' - '# distributed under the License is distributed on an "AS IS" BASIS,\n' - '# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n' - '# See the License for the specific language governing permissions and\n' - '# limitations under the License.\n' - '#\n' - '# This file is automatically generated by CrossSync. 
Do not edit manually.\n' - ) - if self.mypy_ignores: - full_str += f'\n# mypy: disable-error-code="{",".join(self.mypy_ignores)}"\n\n' - full_str += "\n".join([ast.unparse(node) for node in self.imports]) - full_str += "\n\n" - full_str += "\n".join([ast.unparse(node) for node in self.converted_classes.values()]) - if with_black: - cleaned = black.format_str(autoflake.fix_code(full_str, remove_all_unused_imports=True), mode=black.FileMode()) - return cleaned - else: - return full_str - - @classmethod - def get_for_path(cls, path: str) -> CrossSyncArtifact: - if path not in cls._instances: - cls._instances[path] = CrossSyncArtifact(path) - return cls._instances[path] - - def add_class(self, cls): - if cls.cross_sync_class_name in self.converted_classes: - return - crosssync_converter = transformers.SymbolReplacer({"CrossSync": "CrossSync._Sync_Impl"}) - # convert class - cls_node = ast.parse(inspect.getsource(cls)).body[0] - # update name - cls_node.name = cls.cross_sync_class_name - # remove cross_sync decorator - if hasattr(cls_node, "decorator_list"): - cls_node.decorator_list = [d for d in cls_node.decorator_list if not isinstance(d, ast.Call) or not isinstance(d.func, ast.Attribute) or not isinstance(d.func.value, ast.Name) or d.func.value.id != "CrossSync"] - # do ast transformations - converted = transformers.AsyncToSync().visit(cls_node) - if cls.cross_sync_replace_symbols: - converted = transformers.SymbolReplacer(cls.cross_sync_replace_symbols).visit(converted) - converted = crosssync_converter.visit(converted) - converted = transformers.HandleCrossSyncDecorators().visit(converted) - self.converted_classes[cls.cross_sync_class_name] = converted - # add imports for added class if required - if cls.include_file_imports and not self.imports: - with open(inspect.getfile(cls)) as f: - full_ast = ast.parse(f.read()) - for node in full_ast.body: - if isinstance(node, (ast.Import, ast.ImportFrom, ast.If, ast.Try)): - self.imports.append(crosssync_converter.visit(node)) - # add mypy ignore if required - if cls.cross_sync_mypy_ignore: - self.mypy_ignores.extend(cls.cross_sync_mypy_ignore) - if __name__ == "__main__": import os import glob @@ -382,24 +292,12 @@ def add_class(self, cls): import autoflake # find all cross_sync decorated classes search_root = sys.argv[1] - # change directory to the root of the project - orig_dir = os.getcwd() - os.chdir(search_root) - # find all python files rooted here - found_files = [path.replace("/", ".")[:-3] for path in glob.glob("**/*.py", recursive=True)] - # add to path - sys.path.append(".") - found_classes = list(itertools.chain.from_iterable([inspect.getmembers(importlib.import_module(path), inspect.isclass) for path in found_files])) - file_obj_set = set() - for name, cls in found_classes: - if hasattr(cls, "cross_sync_enabled"): - file_obj = CrossSyncArtifact.get_for_path(cls.cross_sync_file_path) - file_obj.add_class(cls) - file_obj_set.add(file_obj) - # write out the files - os.chdir(orig_dir) - for file_obj in file_obj_set: - with open(file_obj.file_path, "w") as f: - f.write(file_obj.render()) - - + # cross_sync_classes = load_classes_from_dir(search_root)\ + files = glob.glob(search_root + "/**/*.py", recursive=True) + artifacts = set() + for file in files: + converter = transformers.CrossSyncClassParser(file) + converter.convert_file(artifacts) + print(artifacts) + for artifact in artifacts: + artifact.render(save_to_disk=True) diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py 
b/google/cloud/bigtable/data/_sync/mutations_batcher.py index e2bd5f370..006982c1f 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -45,6 +45,115 @@ from google.cloud.bigtable.data._sync.client import Table +class _FlowControl: + """ + Manages flow control for batched mutations. Mutations are registered against + the FlowControl object before being sent, which will block if size or count + limits have reached capacity. As mutations completed, they are removed from + the FlowControl object, which will notify any blocked requests that there + is additional capacity. + + Flow limits are not hard limits. If a single mutation exceeds the configured + limits, it will be allowed as a single batch when the capacity is available. + + Args: + max_mutation_count: maximum number of mutations to send in a single rpc. + This corresponds to individual mutations in a single RowMutationEntry. + max_mutation_bytes: maximum number of bytes to send in a single rpc. + Raises: + ValueError: if max_mutation_count or max_mutation_bytes is less than 0 + """ + + def __init__(self, max_mutation_count: int, max_mutation_bytes: int): + self._max_mutation_count = max_mutation_count + self._max_mutation_bytes = max_mutation_bytes + if self._max_mutation_count < 1: + raise ValueError("max_mutation_count must be greater than 0") + if self._max_mutation_bytes < 1: + raise ValueError("max_mutation_bytes must be greater than 0") + self._capacity_condition = CrossSync._Sync_Impl.Condition() + self._in_flight_mutation_count = 0 + self._in_flight_mutation_bytes = 0 + + def _has_capacity(self, additional_count: int, additional_size: int) -> bool: + """Checks if there is capacity to send a new entry with the given size and count + + FlowControl limits are not hard limits. If a single mutation exceeds + the configured flow limits, it will be sent in a single batch when + previous batches have completed. + + Args: + additional_count: number of mutations in the pending entry + additional_size: size of the pending entry + Returns: + bool: True if there is capacity to send the pending entry, False otherwise + """ + acceptable_size = max(self._max_mutation_bytes, additional_size) + acceptable_count = max(self._max_mutation_count, additional_count) + new_size = self._in_flight_mutation_bytes + additional_size + new_count = self._in_flight_mutation_count + additional_count + return new_size <= acceptable_size and new_count <= acceptable_count + + def remove_from_flow( + self, mutations: RowMutationEntry | list[RowMutationEntry] + ) -> None: + """Removes mutations from flow control. This method should be called once + for each mutation that was sent to add_to_flow, after the corresponding + operation is complete. + + Args: + mutations: mutation or list of mutations to remove from flow control""" + if not isinstance(mutations, list): + mutations = [mutations] + total_count = sum((len(entry.mutations) for entry in mutations)) + total_size = sum((entry.size() for entry in mutations)) + self._in_flight_mutation_count -= total_count + self._in_flight_mutation_bytes -= total_size + with self._capacity_condition: + self._capacity_condition.notify_all() + + def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): + """Generator function that registers mutations with flow control. As mutations + are accepted into the flow control, they are yielded back to the caller, + to be sent in a batch. 
If the flow control is at capacity, the generator + will block until there is capacity available. + + Args: + mutations: list mutations to break up into batches + Yields: + list[RowMutationEntry]: + list of mutations that have reserved space in the flow control. + Each batch contains at least one mutation.""" + if not isinstance(mutations, list): + mutations = [mutations] + start_idx = 0 + end_idx = 0 + while end_idx < len(mutations): + start_idx = end_idx + batch_mutation_count = 0 + with self._capacity_condition: + while end_idx < len(mutations): + next_entry = mutations[end_idx] + next_size = next_entry.size() + next_count = len(next_entry.mutations) + if ( + self._has_capacity(next_count, next_size) + and batch_mutation_count + next_count + <= _MUTATE_ROWS_REQUEST_MUTATION_LIMIT + ): + end_idx += 1 + batch_mutation_count += next_count + self._in_flight_mutation_bytes += next_size + self._in_flight_mutation_count += next_count + elif start_idx != end_idx: + break + else: + self._capacity_condition.wait_for( + lambda: self._has_capacity(next_count, next_size) + ) + yield mutations[start_idx:end_idx] + + class MutationsBatcher: """ Allows users to send batches using context manager API: @@ -340,112 +449,3 @@ def _wait_for_batch_results( except Exception as e: exceptions.append(e) return exceptions - - -class _FlowControl: - """ - Manages flow control for batched mutations. Mutations are registered against - the FlowControl object before being sent, which will block if size or count - limits have reached capacity. As mutations completed, they are removed from - the FlowControl object, which will notify any blocked requests that there - is additional capacity. - - Flow limits are not hard limits. If a single mutation exceeds the configured - limits, it will be allowed as a single batch when the capacity is available. - - Args: - max_mutation_count: maximum number of mutations to send in a single rpc. - This corresponds to individual mutations in a single RowMutationEntry. - max_mutation_bytes: maximum number of bytes to send in a single rpc. - Raises: - ValueError: if max_mutation_count or max_mutation_bytes is less than 0 - """ - - def __init__(self, max_mutation_count: int, max_mutation_bytes: int): - self._max_mutation_count = max_mutation_count - self._max_mutation_bytes = max_mutation_bytes - if self._max_mutation_count < 1: - raise ValueError("max_mutation_count must be greater than 0") - if self._max_mutation_bytes < 1: - raise ValueError("max_mutation_bytes must be greater than 0") - self._capacity_condition = CrossSync._Sync_Impl.Condition() - self._in_flight_mutation_count = 0 - self._in_flight_mutation_bytes = 0 - - def _has_capacity(self, additional_count: int, additional_size: int) -> bool: - """Checks if there is capacity to send a new entry with the given size and count - - FlowControl limits are not hard limits. If a single mutation exceeds - the configured flow limits, it will be sent in a single batch when - previous batches have completed. 
- - Args: - additional_count: number of mutations in the pending entry - additional_size: size of the pending entry - Returns: - bool: True if there is capacity to send the pending entry, False otherwise - """ - acceptable_size = max(self._max_mutation_bytes, additional_size) - acceptable_count = max(self._max_mutation_count, additional_count) - new_size = self._in_flight_mutation_bytes + additional_size - new_count = self._in_flight_mutation_count + additional_count - return new_size <= acceptable_size and new_count <= acceptable_count - - def remove_from_flow( - self, mutations: RowMutationEntry | list[RowMutationEntry] - ) -> None: - """Removes mutations from flow control. This method should be called once - for each mutation that was sent to add_to_flow, after the corresponding - operation is complete. - - Args: - mutations: mutation or list of mutations to remove from flow control""" - if not isinstance(mutations, list): - mutations = [mutations] - total_count = sum((len(entry.mutations) for entry in mutations)) - total_size = sum((entry.size() for entry in mutations)) - self._in_flight_mutation_count -= total_count - self._in_flight_mutation_bytes -= total_size - with self._capacity_condition: - self._capacity_condition.notify_all() - - def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): - """Generator function that registers mutations with flow control. As mutations - are accepted into the flow control, they are yielded back to the caller, - to be sent in a batch. If the flow control is at capacity, the generator - will block until there is capacity available. - - Args: - mutations: list mutations to break up into batches - Yields: - list[RowMutationEntry]: - list of mutations that have reserved space in the flow control. - Each batch contains at least one mutation.""" - if not isinstance(mutations, list): - mutations = [mutations] - start_idx = 0 - end_idx = 0 - while end_idx < len(mutations): - start_idx = end_idx - batch_mutation_count = 0 - with self._capacity_condition: - while end_idx < len(mutations): - next_entry = mutations[end_idx] - next_size = next_entry.size() - next_count = len(next_entry.mutations) - if ( - self._has_capacity(next_count, next_size) - and batch_mutation_count + next_count - <= _MUTATE_ROWS_REQUEST_MUTATION_LIMIT - ): - end_idx += 1 - batch_mutation_count += next_count - self._in_flight_mutation_bytes += next_size - self._in_flight_mutation_count += next_count - elif start_idx != end_idx: - break - else: - self._capacity_condition.wait_for( - lambda: self._has_capacity(next_count, next_size) - ) - yield mutations[start_idx:end_idx] diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index fcf9092ea..272325b32 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -1,5 +1,7 @@ import ast +from dataclasses import dataclass, field + class SymbolReplacer(ast.NodeTransformer): def __init__(self, replacements): @@ -108,4 +110,144 @@ def visit_FunctionDef(self, node): node.decorator_list.append(decorator) return node +@dataclass +class CrossSyncFileArtifact: + """ + Used to track an output file location. 
Collects a number of converted classes, and then + writes them to disk + """ + file_path: str + imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = field(default_factory=list) + converted_classes: list[ast.ClassDef] = field(default_factory=list) + contained_classes: set[str] = field(default_factory=set) + mypy_ignore: list[str] = field(default_factory=list) + + def __hash__(self): + return hash(self.file_path) + + def __repr__(self): + return f"CrossSyncFileArtifact({self.file_path}, classes={[c.name for c in self.converted_classes]})" + + def render(self, with_black=True, save_to_disk=False) -> str: + full_str = ( + "# Copyright 2024 Google LLC\n" + "#\n" + '# Licensed under the Apache License, Version 2.0 (the "License");\n' + '# you may not use this file except in compliance with the License.\n' + '# You may obtain a copy of the License at\n' + '#\n' + '# http://www.apache.org/licenses/LICENSE-2.0\n' + '#\n' + '# Unless required by applicable law or agreed to in writing, software\n' + '# distributed under the License is distributed on an "AS IS" BASIS,\n' + '# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n' + '# See the License for the specific language governing permissions and\n' + '# limitations under the License.\n' + '#\n' + '# This file is automatically generated by CrossSync. Do not edit manually.\n' + ) + if self.mypy_ignore: + full_str += f'\n# mypy: disable-error-code="{",".join(self.mypy_ignore)}"\n\n' + full_str += "\n".join([ast.unparse(node) for node in self.imports]) + full_str += "\n\n" + full_str += "\n".join([ast.unparse(node) for node in self.converted_classes]) + if with_black: + import black + import autoflake + full_str = black.format_str(autoflake.fix_code(full_str, remove_all_unused_imports=True), mode=black.FileMode()) + if save_to_disk: + with open(self.file_path, "w") as f: + f.write(full_str) + return full_str + +class CrossSyncClassParser(ast.NodeTransformer): + + def __init__(self, file_path): + self.in_path = file_path + self._artifact_dict = {} + self.imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = [] + self.cross_sync_converter = SymbolReplacer({"CrossSync": "CrossSync._Sync_Impl"}) + + + def convert_file(self, artifacts:set[CrossSyncFileArtifact]|None=None) -> set[CrossSyncFileArtifact]: + """ + Called to run a file through the transformer. If any classes are marked with a CrossSync decorator, + they will be transformed and added to an artifact for the output file + """ + tree = ast.parse(open(self.in_path).read()) + self._artifact_dict = {f.file_path: f for f in artifacts or []} + self.imports = self._get_imports(tree) + self.visit(tree) + found = self._artifact_dict.values() + if artifacts is not None: + artifacts.update(found) + return found + + def visit_ClassDef(self, node): + """ + Called for each class in file. 
If class has a CrossSync decorator, it will be transformed + according to the decorator arguments + """ + for decorator in node.decorator_list: + if "CrossSync" in ast.dump(decorator): + kwargs = {kw.arg: self._convert_ast_to_py(kw.value) for kw in decorator.keywords} + # find the path to write the sync class to + sync_path = kwargs.pop("sync_path", None) + if not sync_path: + sync_path = decorator.args[0].s + out_file = "/".join(sync_path.rsplit(".")[:-1]) + ".py" + sync_cls_name = sync_path.rsplit(".", 1)[-1] + # find the artifact file for the save location + output_artifact = self._artifact_dict.get(out_file, CrossSyncFileArtifact(out_file)) + # write converted class details if not already present + if sync_cls_name not in output_artifact.contained_classes: + converted = self._transform_class(node, sync_cls_name, **kwargs) + output_artifact.converted_classes.append(converted) + # handle file-level mypy ignores + mypy_ignores = [s for s in kwargs.get("mypy_ignore", []) if s not in output_artifact.mypy_ignore] + output_artifact.mypy_ignore.extend(mypy_ignores) + # handle file-level imports + if not output_artifact.imports and kwargs.get("include_file_imports", True): + output_artifact.imports = self.imports + self._artifact_dict[out_file] = output_artifact + return node + + def _transform_class(self, cls_ast: ast.ClassDef, new_name:str, replace_symbols=None, **kwargs) -> ast.ClassDef: + """ + Transform async class into sync one, by running through a series of transformers + """ + # update name + cls_ast.name = new_name + # strip CrossSync decorators + if hasattr(cls_ast, "decorator_list"): + cls_ast.decorator_list = [d for d in cls_ast.decorator_list if "CrossSync" not in ast.dump(d)] + # convert class contents + cls_ast = AsyncToSync().visit(cls_ast) + cls_ast = self.cross_sync_converter.visit(cls_ast) + if replace_symbols: + cls_ast = SymbolReplacer(replace_symbols).visit(cls_ast) + cls_ast = HandleCrossSyncDecorators().visit(cls_ast) + return cls_ast + + def _get_imports(self, tree:ast.Module) -> list[ast.Import | ast.ImportFrom | ast.Try | ast.If]: + """ + Grab the imports from the top of the file + """ + imports = [] + for node in tree.body: + if isinstance(node, (ast.Import, ast.ImportFrom, ast.Try, ast.If)): + imports.append(self.cross_sync_converter.visit(node)) + return imports + + def _convert_ast_to_py(self, ast_node): + """ + Helper to convert ast primitives to python primitives. 
Used when unwrapping kwargs + """ + if isinstance(ast_node, ast.Constant): + return ast_node.value + if isinstance(ast_node, ast.List): + return [self._convert_ast_to_py(node) for node in ast_node.elts] + if isinstance(ast_node, ast.Dict): + return {self._convert_ast_to_py(k): self._convert_ast_to_py(v) for k, v in zip(ast_node.keys, ast_node.values)} + raise ValueError(f"Unsupported type {type(ast_node)}") From 53dbb779827a19ebc607510ed96ea25d7c8a1722 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 17:05:38 -0600 Subject: [PATCH 125/360] added pytest decorator --- .../cloud/bigtable/data/_sync/cross_sync.py | 6 +- .../cloud/bigtable/data/_sync/transformers.py | 10 +- tests/unit/data/_async/test__mutate_rows.py | 18 +- tests/unit/data/_async/test__read_rows.py | 8 +- tests/unit/data/_async/test_client.py | 178 +- .../data/_async/test_mutations_batcher.py | 74 +- .../data/_async/test_read_rows_acceptance.py | 20 +- tests/unit/data/_sync/test_client.py | 3159 ++++++++--------- .../unit/data/_sync/test_mutations_batcher.py | 476 +-- 9 files changed, 1979 insertions(+), 1970 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 716160631..b20214150 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -75,9 +75,13 @@ def sync_output( # return the async class unchanged def decorator(async_cls): return async_cls - return decorator + @staticmethod + def pytest(func): + import pytest + return pytest.mark.asyncio(func) + @staticmethod async def gather_partials( partial_list: Sequence[Callable[[], Awaitable[T]]], diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index 272325b32..d3673f1b8 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -96,8 +96,8 @@ def visit_FunctionDef(self, node): if hasattr(node, "decorator_list"): found_list, node.decorator_list = node.decorator_list, [] for decorator in found_list: - if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name) and "CrossSync" in decorator.func.value.id: - decorator_type = decorator.func.attr + if "CrossSync" in ast.dump(decorator): + decorator_type = decorator.func.attr if hasattr(decorator, "func") else decorator.attr if decorator_type == "convert": for subcommand in decorator.keywords: if subcommand.arg == "sync_name": @@ -105,6 +105,12 @@ def visit_FunctionDef(self, node): if subcommand.arg == "replace_symbols": replacements = {subcommand.value.keys[i].s: subcommand.value.values[i].s for i in range(len(subcommand.value.keys))} node = SymbolReplacer(replacements).visit(node) + elif decorator_type == "pytest": + pass + elif decorator_type == "drop_method": + return None + else: + raise ValueError(f"Unsupported CrossSync decorator: {decorator_type}") else: # add non-crosssync decorators back node.decorator_list.append(decorator) diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 03b7db3f4..55a6fdd40 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -162,7 +162,7 @@ def test_ctor_too_many_entries(self): ) assert "Found 100001" in str(e.value) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_rows_operation(self): """ Test successful case of mutate_rows_operation 
@@ -184,7 +184,7 @@ async def test_mutate_rows_operation(self): @pytest.mark.parametrize( "exc_type", [RuntimeError, ZeroDivisionError, Forbidden] ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_rows_attempt_exception(self, exc_type): """ exceptions raised from attempt should be raised in MutationsExceptionGroup @@ -212,7 +212,7 @@ async def test_mutate_rows_attempt_exception(self, exc_type): @pytest.mark.parametrize( "exc_type", [RuntimeError, ZeroDivisionError, Forbidden] ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_rows_exception(self, exc_type): """ exceptions raised from retryable should be raised in MutationsExceptionGroup @@ -250,7 +250,7 @@ async def test_mutate_rows_exception(self, exc_type): "exc_type", [DeadlineExceeded, RuntimeError], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): """ If an exception fails but eventually passes, it should not raise an exception @@ -279,7 +279,7 @@ async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): await instance.start() assert attempt_mock.call_count == num_retries + 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_rows_incomplete_ignored(self): """ MutateRowsIncomplete exceptions should not be added to error list @@ -310,7 +310,7 @@ async def test_mutate_rows_incomplete_ignored(self): assert len(found_exc.exceptions) == 1 assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) - @pytest.mark.asyncio + @CrossSync.pytest async def test_run_attempt_single_entry_success(self): """Test mutating a single entry""" mutation = self._make_mutation() @@ -328,7 +328,7 @@ async def test_run_attempt_single_entry_success(self): assert kwargs["timeout"] == expected_timeout assert kwargs["entries"] == [mutation._to_pb()] - @pytest.mark.asyncio + @CrossSync.pytest async def test_run_attempt_empty_request(self): """Calling with no mutations should result in no API calls""" mock_gapic_fn = self._make_mock_gapic([]) @@ -338,7 +338,7 @@ async def test_run_attempt_empty_request(self): await instance._run_attempt() assert mock_gapic_fn.call_count == 0 - @pytest.mark.asyncio + @CrossSync.pytest async def test_run_attempt_partial_success_retryable(self): """Some entries succeed, but one fails. Should report the proper index, and raise incomplete exception""" from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete @@ -361,7 +361,7 @@ async def test_run_attempt_partial_success_retryable(self): assert instance.errors[1][0].grpc_status_code == 300 assert 2 not in instance.errors - @pytest.mark.asyncio + @CrossSync.pytest async def test_run_attempt_partial_success_non_retryable(self): """Some entries succeed, but one fails. Exception marked as non-retryable. 
Do not raise incomplete error""" success_mutation = self._make_mutation() diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index abda3af05..130ecae7a 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -245,7 +245,7 @@ def test_revise_to_empty_rowset(self): (4, 2, 2), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_revise_limit(self, start_limit, emit_num, expected_limit): """ revise_limit should revise the request's limit field @@ -286,7 +286,7 @@ async def mock_stream(): assert instance._remaining_count == expected_limit @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_revise_limit_over_limit(self, start_limit, emit_num): """ Should raise runtime error if we get in state where emit_num > start_num @@ -325,7 +325,7 @@ async def mock_stream(): pass assert "emit count exceeds row limit" in str(e.value) - @pytest.mark.asyncio + @CrossSync.pytest async def test_aclose(self): """ should be able to close a stream safely with aclose. @@ -354,7 +354,7 @@ async def mock_stream(): with pytest.raises(StopAsyncIteration): await wrapped_gen.__anext__() - @pytest.mark.asyncio + @CrossSync.pytest async def test_retryable_ignore_repeated_rows(self): """ Duplicate rows should cause an invalid chunk error diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index b7a5cdba0..5f6af8be8 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -104,7 +104,7 @@ def _make_client(cls, *args, use_emulator=True, **kwargs): with mock.patch.dict(os.environ, env_mask): return cls._get_target_class()(*args, **kwargs) - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor(self): expected_project = "project-id" expected_pool_size = 11 @@ -123,7 +123,7 @@ async def test_ctor(self): assert client.transport._credentials == expected_credentials await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor_super_inits(self): from google.cloud.client import ClientWithProject from google.api_core import client_options as client_options_lib @@ -165,7 +165,7 @@ async def test_ctor_super_inits(self): assert kwargs["credentials"] == credentials assert kwargs["client_options"] == options_parsed - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor_dict_options(self): from google.api_core.client_options import ClientOptions @@ -189,7 +189,7 @@ async def test_ctor_dict_options(self): start_background_refresh.assert_called_once() await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_veneer_grpc_headers(self): client_component = "data-async" if CrossSync.is_async else "data" VENEER_HEADER_REGEX = re.compile( @@ -220,7 +220,7 @@ async def test_veneer_grpc_headers(self): ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_channel_pool_creation(self): pool_size = 14 with mock.patch.object( @@ -236,7 +236,7 @@ async def test_channel_pool_creation(self): assert len(pool_list) == len(pool_set) await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_channel_pool_rotation(self): pool_size = 7 @@ -260,7 +260,7 @@ async def test_channel_pool_rotation(self): unary_unary.reset_mock() await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def 
test_channel_pool_replace(self): import time @@ -300,7 +300,7 @@ def test__start_background_channel_refresh_sync(self): with pytest.raises(RuntimeError): client._start_background_channel_refresh() - @pytest.mark.asyncio + @CrossSync.pytest async def test__start_background_channel_refresh_tasks_exist(self): # if tasks exist, should do nothing client = self._make_client(project="project-id", use_emulator=False) @@ -310,7 +310,7 @@ async def test__start_background_channel_refresh_tasks_exist(self): create_task.assert_not_called() await client.close() - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize("pool_size", [1, 3, 7]) async def test__start_background_channel_refresh(self, pool_size): import concurrent.futures @@ -337,7 +337,7 @@ async def test__start_background_channel_refresh(self, pool_size): await client.close() @CrossSync.drop_method - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.skipif( sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" ) @@ -353,7 +353,7 @@ async def test__start_background_channel_refresh_tasks_names(self): assert "BigtableDataClientAsync channel refresh " in name await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__ping_and_warm_instances(self): """ test ping and warm with mocked asyncio.gather @@ -411,7 +411,7 @@ async def test__ping_and_warm_instances(self): == f"name={expected_instance}&app_profile_id={expected_app_profile}" ) - @pytest.mark.asyncio + @CrossSync.pytest async def test__ping_and_warm_single_instance(self): """ should be able to call ping and warm with single instance @@ -447,7 +447,7 @@ async def test__ping_and_warm_single_instance(self): metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" ) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "refresh_interval, wait_time, expected_sleep", [ @@ -483,7 +483,7 @@ async def test__manage_channel_first_sleep( ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__manage_channel_ping_and_warm(self): """ _manage channel should call ping and warm internally @@ -530,7 +530,7 @@ async def test__manage_channel_ping_and_warm(self): pass ping_and_warm.assert_called_once_with(new_channel) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "refresh_interval, num_cycles, expected_sleep", [ @@ -580,7 +580,7 @@ async def test__manage_channel_sleeps( ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__manage_channel_random(self): import random import threading @@ -611,7 +611,7 @@ async def test__manage_channel_random(self): assert found_min == min_val assert found_max == max_val - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) async def test__manage_channel_refresh(self, num_cycles): # make sure that channels are properly refreshed @@ -663,7 +663,7 @@ async def test__manage_channel_refresh(self, num_cycles): assert kwargs["new_channel"] == new_channel await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__register_instance(self): """ test instance registration @@ -734,7 +734,7 @@ async def test__register_instance(self): ] ) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "insert_instances,expected_active,expected_owner_keys", [ @@ -793,7 
+793,7 @@ async def test__register_instance_state( ] ) - @pytest.mark.asyncio + @CrossSync.pytest async def test__remove_instance_registration(self): client = self._make_client(project="project-id") table = mock.Mock() @@ -824,7 +824,7 @@ async def test__remove_instance_registration(self): assert len(client._active_instances) == 1 await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__multiple_table_registration(self): """ registering with multiple tables with the same key should @@ -881,7 +881,7 @@ async def test__multiple_table_registration(self): assert instance_1_key not in client._active_instances assert len(client._instance_owners[instance_1_key]) == 0 - @pytest.mark.asyncio + @CrossSync.pytest async def test__multiple_instance_registration(self): """ registering with multiple instance keys should update the key @@ -928,7 +928,7 @@ async def test__multiple_instance_registration(self): assert len(client._instance_owners[instance_1_key]) == 0 assert len(client._instance_owners[instance_2_key]) == 0 - @pytest.mark.asyncio + @CrossSync.pytest async def test_get_table(self): from google.cloud.bigtable.data._helpers import _WarmedInstanceKey @@ -963,7 +963,7 @@ async def test_get_table(self): assert client._instance_owners[instance_key] == {id(table)} await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_get_table_arg_passthrough(self): """ All arguments passed in get_table should be sent to constructor @@ -996,7 +996,7 @@ async def test_get_table_arg_passthrough(self): **expected_kwargs, ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_get_table_context_manager(self): from google.cloud.bigtable.data._helpers import _WarmedInstanceKey @@ -1035,7 +1035,7 @@ async def test_get_table_context_manager(self): assert client._instance_owners[instance_key] == {id(table)} assert close_mock.call_count == 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_multiple_pool_sizes(self): # should be able to create multiple clients with different pool sizes without issue pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] @@ -1052,7 +1052,7 @@ async def test_multiple_pool_sizes(self): await client.close() await client_duplicate.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_close(self): pool_size = 7 client = self._make_client( @@ -1073,7 +1073,7 @@ async def test_close(self): assert task.done() assert client._channel_refresh_tasks == [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_close_with_timeout(self): pool_size = 7 expected_timeout = 19 @@ -1088,7 +1088,7 @@ async def test_close_with_timeout(self): client._channel_refresh_tasks = tasks await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_context_manager(self): # context manager should close the client cleanly close_mock = AsyncMock() @@ -1134,7 +1134,7 @@ def _make_client(self, *args, **kwargs): def _get_target_class(): return TableAsync - @pytest.mark.asyncio + @CrossSync.pytest async def test_table_ctor(self): from google.cloud.bigtable.data._helpers import _WarmedInstanceKey @@ -1197,7 +1197,7 @@ async def test_table_ctor(self): assert table._register_instance_future.exception() is None await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_table_ctor_defaults(self): """ should provide default timeout values and app_profile_id @@ -1227,7 +1227,7 @@ async def test_table_ctor_defaults(self): assert table.default_mutate_rows_attempt_timeout == 60 await client.close() - @pytest.mark.asyncio + 
@CrossSync.pytest async def test_table_ctor_invalid_timeout_values(self): """ bad timeout values should raise ValueError @@ -1266,7 +1266,7 @@ def test_table_ctor_sync(self): TableAsync(client, "instance-id", "table-id") assert e.match("TableAsync must be created within an async event loop context.") - @pytest.mark.asyncio + @CrossSync.pytest # iterate over all retryable rpcs @pytest.mark.parametrize( "fn_name,fn_args,is_stream,extra_retryables", @@ -1407,7 +1407,7 @@ async def test_customizable_retryable_errors( ], ) @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): """check that all requests attach proper metadata headers""" from google.cloud.bigtable.data import TableAsync @@ -1537,7 +1537,7 @@ def cancel(self): async def execute_fn(self, table, *args, **kwargs): return await table.read_rows(*args, **kwargs) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows(self): query = ReadRowsQuery() chunks = [ @@ -1554,7 +1554,7 @@ async def test_read_rows(self): assert results[0].row_key == b"test_1" assert results[1].row_key == b"test_2" - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_stream(self): query = ReadRowsQuery() chunks = [ @@ -1573,7 +1573,7 @@ async def test_read_rows_stream(self): assert results[1].row_key == b"test_2" @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_query_matches_request(self, include_app_profile): from google.cloud.bigtable.data import RowRange from google.cloud.bigtable.data.row_filters import PassAllFilter @@ -1600,7 +1600,7 @@ async def test_read_rows_query_matches_request(self, include_app_profile): assert call_request == query_pb @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_timeout(self, operation_timeout): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows @@ -1625,7 +1625,7 @@ async def test_read_rows_timeout(self, operation_timeout): (0.05, 0.24, 5), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_attempt_timeout( self, per_request_t, operation_t, expected_num ): @@ -1688,7 +1688,7 @@ async def test_read_rows_attempt_timeout( core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_retryable_error(self, exc_type): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows @@ -1719,7 +1719,7 @@ async def test_read_rows_retryable_error(self, exc_type): InvalidChunk, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_non_retryable_error(self, exc_type): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows @@ -1733,7 +1733,7 @@ async def test_read_rows_non_retryable_error(self, exc_type): except exc_type as e: assert e == expected_error - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_revise_request(self): """ Ensure that _revise_request is called between retries @@ -1767,7 +1767,7 @@ async def test_read_rows_revise_request(self): revised_call = read_rows.call_args_list[1].args[0] assert revised_call.rows == return_val - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_default_timeouts(self): """ Ensure that the default timeouts are set on the read rows operation when not overridden 
@@ -1788,7 +1788,7 @@ async def test_read_rows_default_timeouts(self): assert kwargs["operation_timeout"] == operation_timeout assert kwargs["attempt_timeout"] == attempt_timeout - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_default_timeout_override(self): """ When timeouts are passed, they overwrite default values @@ -1812,7 +1812,7 @@ async def test_read_rows_default_timeout_override(self): assert kwargs["operation_timeout"] == operation_timeout assert kwargs["attempt_timeout"] == attempt_timeout - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_row(self): """Test reading a single row""" async with self._make_client() as client: @@ -1840,7 +1840,7 @@ async def test_read_row(self): assert query.row_ranges == [] assert query.limit == 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_row_w_filter(self): """Test reading a single row with an added filter""" async with self._make_client() as client: @@ -1873,7 +1873,7 @@ async def test_read_row_w_filter(self): assert query.limit == 1 assert query.filter == expected_filter - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_row_no_response(self): """should return None if row does not exist""" async with self._make_client() as client: @@ -1908,7 +1908,7 @@ async def test_read_row_no_response(self): ([object(), object()], True), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_row_exists(self, return_value, expected_result): """Test checking for row existence""" async with self._make_client() as client: @@ -1953,7 +1953,7 @@ class TestReadRowsShardedAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_empty_query(self): async with self._make_client() as client: async with client.get_table("instance", "table") as table: @@ -1961,7 +1961,7 @@ async def test_read_rows_sharded_empty_query(self): await table.read_rows_sharded([]) assert "empty sharded_query" in str(exc.value) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_multiple_queries(self): """ Test with multiple queries. Should return results from both @@ -1987,7 +1987,7 @@ async def test_read_rows_sharded_multiple_queries(self): assert result[1].row_key == b"test_2" @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_multiple_queries_calls(self, n_queries): """ Each query should trigger a separate read_rows call @@ -1999,7 +1999,7 @@ async def test_read_rows_sharded_multiple_queries_calls(self, n_queries): await table.read_rows_sharded(query_list) assert read_rows.call_count == n_queries - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_errors(self): """ Errors should be exposed as ShardedReadRowsExceptionGroups @@ -2027,7 +2027,7 @@ async def test_read_rows_sharded_errors(self): assert exc.value.exceptions[1].index == 1 assert exc.value.exceptions[1].query == query_2 - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_concurrent(self): """ Ensure sharded requests are concurrent @@ -2051,7 +2051,7 @@ async def mock_call(*args, **kwargs): # if run in sequence, we would expect this to take 1 second assert call_time < 0.2 - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_concurrency_limit(self): """ Only 10 queries should be processed concurrently. 
Others should be queued @@ -2100,7 +2100,7 @@ async def mock_call(*args, **kwargs): idx = i + _CONCURRENCY_LIMIT assert rpc_start_list[idx] - (i * increment_time) < eps - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_expirary(self): """ If the operation times out before all shards complete, should raise @@ -2140,7 +2140,7 @@ async def mock_call(*args, **kwargs): # should keep successful queries assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_negative_batch_timeout(self): """ try to run with batch that starts after operation timeout @@ -2183,7 +2183,7 @@ async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): for value in sample_list: yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys(self): """ Test that method returns the expected key samples @@ -2208,7 +2208,7 @@ async def test_sample_row_keys(self): assert result[1] == samples[1] assert result[2] == samples[2] - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_bad_timeout(self): """ should raise error if timeout is negative @@ -2222,7 +2222,7 @@ async def test_sample_row_keys_bad_timeout(self): await table.sample_row_keys(attempt_timeout=-1) assert "attempt_timeout must be greater than 0" in str(e.value) - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_default_timeout(self): """Should fallback to using table default operation_timeout""" expected_timeout = 99 @@ -2243,7 +2243,7 @@ async def test_sample_row_keys_default_timeout(self): assert result == [] assert kwargs["retry"] is None - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_gapic_params(self): """ make sure arguments are propagated to gapic call as expected @@ -2277,7 +2277,7 @@ async def test_sample_row_keys_gapic_params(self): core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_retryable_errors(self, retryable_exception): """ retryable errors should be retried until timeout @@ -2309,7 +2309,7 @@ async def test_sample_row_keys_retryable_errors(self, retryable_exception): core_exceptions.Aborted, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): """ non-retryable errors should cause a raise @@ -2332,7 +2332,7 @@ class TestMutateRowAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "mutation_arg", [ @@ -2388,7 +2388,7 @@ async def test_mutate_row(self, mutation_arg): core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_retryable_errors(self, retryable_exception): from google.api_core.exceptions import DeadlineExceeded from google.cloud.bigtable.data.exceptions import RetryExceptionGroup @@ -2416,7 +2416,7 @@ async def test_mutate_row_retryable_errors(self, retryable_exception): core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_non_idempotent_retryable_errors( self, retryable_exception ): @@ -2449,7 +2449,7 @@ async def test_mutate_row_non_idempotent_retryable_errors( core_exceptions.Aborted, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_non_retryable_errors(self, 
non_retryable_exception): async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: @@ -2470,7 +2470,7 @@ async def test_mutate_row_non_retryable_errors(self, non_retryable_exception): ) @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_metadata(self, include_app_profile): """request should attach metadata headers""" profile = "profile" if include_app_profile else None @@ -2494,7 +2494,7 @@ async def test_mutate_row_metadata(self, include_app_profile): assert "app_profile_id=" not in goog_metadata @pytest.mark.parametrize("mutations", [[], None]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_no_mutations(self, mutations): async with self._make_client() as client: async with client.get_table("instance", "table") as table: @@ -2535,8 +2535,8 @@ async def generator(): return generator() - @pytest.mark.asyncio - @pytest.mark.asyncio + @CrossSync.pytest + @CrossSync.pytest @pytest.mark.parametrize( "mutation_arg", [ @@ -2580,7 +2580,7 @@ async def test_bulk_mutate_rows(self, mutation_arg): assert kwargs["timeout"] == expected_attempt_timeout assert kwargs["retry"] is None - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_rows_multiple_entries(self): """Test mutations with no errors""" async with self._make_client(project="project") as client: @@ -2604,7 +2604,7 @@ async def test_bulk_mutate_rows_multiple_entries(self): assert kwargs["entries"][0] == entry_1._to_pb() assert kwargs["entries"][1] == entry_2._to_pb() - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "exception", [ @@ -2649,7 +2649,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_retryable( cause.exceptions[-1], core_exceptions.DeadlineExceeded ) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "exception", [ @@ -2697,7 +2697,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable( core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_idempotent_retryable_request_errors( self, retryable_exception ): @@ -2731,7 +2731,7 @@ async def test_bulk_mutate_idempotent_retryable_request_errors( assert isinstance(cause, RetryExceptionGroup) assert isinstance(cause.exceptions[0], retryable_exception) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "retryable_exception", [ @@ -2780,7 +2780,7 @@ async def test_bulk_mutate_rows_non_idempotent_retryable_errors( ValueError, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception): """ If the request fails with a non-retryable error, mutations should not be retried @@ -2810,7 +2810,7 @@ async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_excepti cause = failed_exception.__cause__ assert isinstance(cause, non_retryable_exception) - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_error_index(self): """ Test partial failure, partial success. 
Errors should be associated with the correct index @@ -2861,7 +2861,7 @@ async def test_bulk_mutate_error_index(self): assert isinstance(cause.exceptions[1], DeadlineExceeded) assert isinstance(cause.exceptions[2], FailedPrecondition) - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_error_recovery(self): """ If an error occurs, then resolves, no exception should be raised @@ -2895,7 +2895,7 @@ def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @pytest.mark.parametrize("gapic_result", [True, False]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate(self, gapic_result): from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse @@ -2937,7 +2937,7 @@ async def test_check_and_mutate(self, gapic_result): assert kwargs["timeout"] == operation_timeout assert kwargs["retry"] is None - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate_bad_timeout(self): """Should raise error if operation_timeout < 0""" async with self._make_client() as client: @@ -2952,7 +2952,7 @@ async def test_check_and_mutate_bad_timeout(self): ) assert str(e.value) == "operation_timeout must be greater than 0" - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate_single_mutations(self): """if single mutations are passed, they should be internally wrapped in a list""" from google.cloud.bigtable.data.mutations import SetCell @@ -2978,7 +2978,7 @@ async def test_check_and_mutate_single_mutations(self): assert kwargs["true_mutations"] == [true_mutation._to_pb()] assert kwargs["false_mutations"] == [false_mutation._to_pb()] - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate_predicate_object(self): """predicate filter should be passed to gapic request""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse @@ -3004,7 +3004,7 @@ async def test_check_and_mutate_predicate_object(self): assert mock_predicate._to_pb.call_count == 1 assert kwargs["retry"] is None - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate_mutations_parsing(self): """mutations objects should be converted to protos""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse @@ -3070,7 +3070,7 @@ def _make_client(self, *args, **kwargs): ), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): """ Test that the gapic call is called with given rules @@ -3087,7 +3087,7 @@ async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules assert found_kwargs["retry"] is None @pytest.mark.parametrize("rules", [[], None]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_no_rules(self, rules): async with self._make_client() as client: async with client.get_table("instance", "table") as table: @@ -3095,7 +3095,7 @@ async def test_read_modify_write_no_rules(self, rules): await table.read_modify_write_row("key", rules=rules) assert e.value.args[0] == "rules must contain at least one item" - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_call_defaults(self): instance = "instance1" table_id = "table1" @@ -3117,7 +3117,7 @@ async def test_read_modify_write_call_defaults(self): assert kwargs["row_key"] == row_key.encode() assert kwargs["timeout"] > 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_call_overrides(self): row_key = b"row_key1" expected_timeout = 12345 @@ -3140,7 +3140,7 
@@ async def test_read_modify_write_call_overrides(self): assert kwargs["row_key"] == row_key assert kwargs["timeout"] == expected_timeout - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_string_key(self): row_key = "string_row_key1" async with self._make_client() as client: @@ -3153,7 +3153,7 @@ async def test_read_modify_write_string_key(self): kwargs = mock_gapic.call_args_list[0][1] assert kwargs["row_key"] == row_key.encode() - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_row_building(self): """ results from gapic call should be used to construct row diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index d7f5b68cf..8ef05326d 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -122,7 +122,7 @@ def test__has_capacity( instance._in_flight_mutation_bytes = existing_size assert instance._has_capacity(new_count, new_size) == expected - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "existing_count,existing_size,added_count,added_size,new_count,new_size", [ @@ -155,7 +155,7 @@ async def test_remove_from_flow_value_update( assert instance._in_flight_mutation_count == new_count assert instance._in_flight_mutation_bytes == new_size - @pytest.mark.asyncio + @CrossSync.pytest async def test__remove_from_flow_unlock(self): """capacity condition should notify after mutation is complete""" import inspect @@ -210,7 +210,7 @@ async def task_routine(): # task should be complete assert task_alive() is False - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "mutations,count_cap,size_cap,expected_results", [ @@ -251,7 +251,7 @@ async def test_add_to_flow(self, mutations, count_cap, size_cap, expected_result i += 1 assert i == len(expected_results) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "mutations,max_limit,expected_results", [ @@ -295,7 +295,7 @@ async def test_add_to_flow_max_mutation_limits( i += 1 assert i == len(expected_results) - @pytest.mark.asyncio + @CrossSync.pytest async def test_add_to_flow_oversize(self): """ mutations over the flow control limits should still be accepted @@ -349,7 +349,7 @@ def _make_mutation(count=1, size=1): mutation.mutations = [mock.Mock()] * count return mutation - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor_defaults(self): with mock.patch.object( self._get_target_class(), "_timer_routine", return_value=asyncio.Future() @@ -389,7 +389,7 @@ async def test_ctor_defaults(self): assert flush_timer_mock.call_args[0][0] == 5 assert isinstance(instance._flush_timer, asyncio.Future) - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor_explicit(self): """Test with explicit parameters""" with mock.patch.object( @@ -441,7 +441,7 @@ async def test_ctor_explicit(self): assert flush_timer_mock.call_args[0][0] == flush_interval assert isinstance(instance._flush_timer, asyncio.Future) - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor_no_flush_limits(self): """Test with None for flush limits""" with mock.patch.object( @@ -475,7 +475,7 @@ async def test_ctor_no_flush_limits(self): assert flush_timer_mock.call_args[0][0] is None assert isinstance(instance._flush_timer, asyncio.Future) - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor_invalid_values(self): """Test that timeout values are positive, and fit within expected limits""" with pytest.raises(ValueError) as e: @@ -513,7 +513,7 
@@ def test_default_argument_consistency(self): == batcher_init_signature[arg_name].default ) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize("input_val", [None, 0, -1]) async def test__start_flush_timer_w_empty_input(self, input_val): """Empty/invalid timer should return immediately""" @@ -532,7 +532,7 @@ async def test__start_flush_timer_w_empty_input(self, input_val): assert flush_mock.call_count == 0 assert result is None - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test__start_flush_timer_call_when_closed( self, @@ -554,7 +554,7 @@ async def test__start_flush_timer_call_when_closed( assert sleep_mock.call_count == 0 assert flush_mock.call_count == 0 - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize("num_staged", [0, 1, 10]) @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test__flush_timer(self, num_staged): @@ -581,7 +581,7 @@ async def test__flush_timer(self, num_staged): assert sleep_kwargs["timeout"] == expected_sleep assert flush_mock.call_count == (0 if num_staged == 0 else loop_num) - @pytest.mark.asyncio + @CrossSync.pytest async def test__flush_timer_close(self): """Timer should continue terminate after close""" with mock.patch.object(self._get_target_class(), "_schedule_flush"): @@ -596,7 +596,7 @@ async def test__flush_timer_close(self): # task should be complete assert instance._flush_timer.done() is True - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_closed(self): """Should raise exception""" instance = self._make_one() @@ -604,7 +604,7 @@ async def test_append_closed(self): with pytest.raises(RuntimeError): await instance.append(mock.Mock()) - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_wrong_mutation(self): """ Mutation objects should raise an exception. 
@@ -618,7 +618,7 @@ async def test_append_wrong_mutation(self): await instance.append(DeleteAllFromRow()) assert str(e.value) == expected_error - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_outside_flow_limits(self): """entries larger than mutation limits are still processed""" async with self._make_one( @@ -640,7 +640,7 @@ async def test_append_outside_flow_limits(self): assert instance._staged_bytes == 0 instance._staged_entries = [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_flush_runs_after_limit_hit(self): """ If the user appends a bunch of entries above the flush limits back-to-back, @@ -682,7 +682,7 @@ async def mock_call(*args, **kwargs): (1, 1, 0, 0, False), ], ) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_append( self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush @@ -703,7 +703,7 @@ async def test_append( assert instance._staged_entries == [mutation] instance._staged_entries = [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_multiple_sequentially(self): """Append multiple mutations""" async with self._make_one( @@ -731,7 +731,7 @@ async def test_append_multiple_sequentially(self): assert len(instance._staged_entries) == 3 instance._staged_entries = [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_flush_flow_control_concurrent_requests(self): """ requests should happen in parallel if flow control breaks up single flush into batches @@ -770,7 +770,7 @@ async def mock_call(*args, **kwargs): assert duration < 0.5 assert op_mock.call_count == num_calls - @pytest.mark.asyncio + @CrossSync.pytest async def test_schedule_flush_no_mutations(self): """schedule flush should return None if no staged mutations""" async with self._make_one() as instance: @@ -779,7 +779,7 @@ async def test_schedule_flush_no_mutations(self): assert instance._schedule_flush() is None assert flush_mock.call_count == 0 - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_schedule_flush_with_mutations(self): """if new mutations exist, should add a new flush task to _flush_jobs""" @@ -801,7 +801,7 @@ async def test_schedule_flush_with_mutations(self): assert flush_mock.call_count == 1 flush_mock.reset_mock() - @pytest.mark.asyncio + @CrossSync.pytest async def test__flush_internal(self): """ _flush_internal should: @@ -829,7 +829,7 @@ async def gen(x): instance._oldest_exceptions.clear() instance._newest_exceptions.clear() - @pytest.mark.asyncio + @CrossSync.pytest async def test_flush_clears_job_list(self): """ a job should be added to _flush_jobs when _schedule_flush is called, @@ -865,7 +865,7 @@ async def test_flush_clears_job_list(self): (10, 20, 20), # should cap at 20 ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test__flush_internal_with_errors( self, num_starting, num_new_errors, expected_total_errors ): @@ -920,7 +920,7 @@ async def gen(num): return gen(num) - @pytest.mark.asyncio + @CrossSync.pytest async def test_timer_flush_end_to_end(self): """Flush should automatically trigger after flush_interval""" num_nutations = 10 @@ -942,7 +942,7 @@ async def test_timer_flush_end_to_end(self): await asyncio.sleep(0.1) assert instance._entries_processed_since_last_raise == num_nutations - @pytest.mark.asyncio + @CrossSync.pytest async def test__execute_mutate_rows(self): if self.is_async(): mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" @@ -969,7 +969,7 @@ async 
def test__execute_mutate_rows(self): kwargs["attempt_timeout"] == 13 assert result == [] - @pytest.mark.asyncio + @CrossSync.pytest async def test__execute_mutate_rows_returns_errors(self): """Errors from operation should be retruned as list""" from google.cloud.bigtable.data.exceptions import ( @@ -1001,7 +1001,7 @@ async def test__execute_mutate_rows_returns_errors(self): assert result[0].index is None assert result[1].index is None - @pytest.mark.asyncio + @CrossSync.pytest async def test__raise_exceptions(self): """Raise exceptions and reset error state""" from google.cloud.bigtable.data import exceptions @@ -1021,13 +1021,13 @@ async def test__raise_exceptions(self): # try calling again instance._raise_exceptions() - @pytest.mark.asyncio + @CrossSync.pytest async def test___aenter__(self): """Should return self""" async with self._make_one() as instance: assert await instance.__aenter__() == instance - @pytest.mark.asyncio + @CrossSync.pytest async def test___aexit__(self): """aexit should call close""" async with self._make_one() as instance: @@ -1035,7 +1035,7 @@ async def test___aexit__(self): await instance.__aexit__(None, None, None) assert close_mock.call_count == 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_close(self): """Should clean up all resources""" async with self._make_one() as instance: @@ -1048,7 +1048,7 @@ async def test_close(self): assert flush_mock.call_count == 1 assert raise_mock.call_count == 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_close_w_exceptions(self): """Raise exceptions on close""" from google.cloud.bigtable.data import exceptions @@ -1067,7 +1067,7 @@ async def test_close_w_exceptions(self): # clear out exceptions instance._oldest_exceptions, instance._newest_exceptions = ([], []) - @pytest.mark.asyncio + @CrossSync.pytest async def test__on_exit(self, recwarn): """Should raise warnings if unflushed mutations exist""" async with self._make_one() as instance: @@ -1089,7 +1089,7 @@ async def test__on_exit(self, recwarn): # reset staged mutations for cleanup instance._staged_entries = [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_atexit_registration(self): """Should run _on_exit on program termination""" import atexit @@ -1099,7 +1099,7 @@ async def test_atexit_registration(self): async with self._make_one(): assert register_mock.call_count == 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_timeout_args_passed(self): """ batch_operation_timeout and batch_attempt_timeout should be used @@ -1186,7 +1186,7 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): for i in range(1, newest_list_diff + 1): assert mock_batcher._newest_exceptions[-i] == input_list[-i] - @pytest.mark.asyncio + @CrossSync.pytest # test different inputs for retryable exceptions @pytest.mark.parametrize( "input_retryables,expected_retryables", diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 1779ae2ac..c8b49bdab 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -117,7 +117,7 @@ async def _row_stream(): @pytest.mark.parametrize( "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_row_merger_scenario(self, test_case: ReadRowsTest): async def _scenerio_stream(): for chunk in test_case.chunks: @@ -151,7 +151,7 @@ async def _scenerio_stream(): @pytest.mark.parametrize( "test_case", 
parse_readrows_acceptance_tests(), ids=lambda t: t.description ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_scenario(self, test_case: ReadRowsTest): async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): from google.cloud.bigtable_v2 import ReadRowsResponse @@ -207,7 +207,7 @@ def cancel(self): for expected, actual in zip_longest(test_case.results, results): assert actual == expected - @pytest.mark.asyncio + @CrossSync.pytest async def test_out_of_order_rows(self): async def _row_stream(): yield ReadRowsResponse(last_scanned_row_key=b"a") @@ -223,7 +223,7 @@ async def _row_stream(): async for _ in merger: pass - @pytest.mark.asyncio + @CrossSync.pytest async def test_bare_reset(self): first_chunk = ReadRowsResponse.CellChunk( ReadRowsResponse.CellChunk( @@ -273,7 +273,7 @@ async def test_bare_reset(self): ), ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_missing_family(self): with pytest.raises(InvalidChunk): await self._process_chunks( @@ -286,7 +286,7 @@ async def test_missing_family(self): ) ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mid_cell_row_key_change(self): with pytest.raises(InvalidChunk): await self._process_chunks( @@ -301,7 +301,7 @@ async def test_mid_cell_row_key_change(self): ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mid_cell_family_change(self): with pytest.raises(InvalidChunk): await self._process_chunks( @@ -318,7 +318,7 @@ async def test_mid_cell_family_change(self): ), ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mid_cell_qualifier_change(self): with pytest.raises(InvalidChunk): await self._process_chunks( @@ -335,7 +335,7 @@ async def test_mid_cell_qualifier_change(self): ), ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mid_cell_timestamp_change(self): with pytest.raises(InvalidChunk): await self._process_chunks( @@ -352,7 +352,7 @@ async def test_mid_cell_timestamp_change(self): ), ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mid_cell_labels_change(self): with pytest.raises(InvalidChunk): await self._process_chunks( diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 9df88dad3..6194cbc15 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -17,7 +17,6 @@ import grpc import asyncio import re -import sys import pytest from google.cloud.bigtable.data import mutations from google.auth.credentials import AnonymousCredentials @@ -952,1767 +951,1767 @@ def test_context_manager(self): true_close -class TestBulkMutateRows: +class TestTable: def _make_client(self, *args, **kwargs): return TestBigtableDataClient._make_client(*args, **kwargs) - def _mock_response(self, response_list): - from google.cloud.bigtable_v2.types import MutateRowsResponse - from google.rpc import status_pb2 + @staticmethod + def _get_target_class(): + return Table - statuses = [] - for response in response_list: - if isinstance(response, core_exceptions.GoogleAPICallError): - statuses.append( - status_pb2.Status( - message=str(response), code=response.grpc_status_code.value[0] - ) - ) - else: - statuses.append(status_pb2.Status(code=0)) - entries = [ - MutateRowsResponse.Entry(index=i, status=statuses[i]) - for i in range(len(response_list)) - ] + def test_table_ctor(self): + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - def generator(): - yield MutateRowsResponse(entries=entries) + expected_table_id 
= "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_operation_timeout = 123 + expected_attempt_timeout = 12 + expected_read_rows_operation_timeout = 1.5 + expected_read_rows_attempt_timeout = 0.5 + expected_mutate_rows_operation_timeout = 2.5 + expected_mutate_rows_attempt_timeout = 0.75 + client = self._make_client() + assert not client._active_instances + table = self._get_target_class()( + client, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + default_operation_timeout=expected_operation_timeout, + default_attempt_timeout=expected_attempt_timeout, + default_read_rows_operation_timeout=expected_read_rows_operation_timeout, + default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout, + default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, + default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, + ) + CrossSync._Sync_Impl.yield_to_event_loop() + assert table.table_id == expected_table_id + assert table.instance_id == expected_instance_id + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + assert table.default_operation_timeout == expected_operation_timeout + assert table.default_attempt_timeout == expected_attempt_timeout + assert ( + table.default_read_rows_operation_timeout + == expected_read_rows_operation_timeout + ) + assert ( + table.default_read_rows_attempt_timeout + == expected_read_rows_attempt_timeout + ) + assert ( + table.default_mutate_rows_operation_timeout + == expected_mutate_rows_operation_timeout + ) + assert ( + table.default_mutate_rows_attempt_timeout + == expected_mutate_rows_attempt_timeout + ) + table._register_instance_future + assert table._register_instance_future.done() + assert not table._register_instance_future.cancelled() + assert table._register_instance_future.exception() is None + client.close() - return generator() + def test_table_ctor_defaults(self): + """should provide default timeout values and app_profile_id""" + expected_table_id = "table-id" + expected_instance_id = "instance-id" + client = self._make_client() + assert not client._active_instances + table = Table(client, expected_instance_id, expected_table_id) + CrossSync._Sync_Impl.yield_to_event_loop() + assert table.table_id == expected_table_id + assert table.instance_id == expected_instance_id + assert table.app_profile_id is None + assert table.client is client + assert table.default_operation_timeout == 60 + assert table.default_read_rows_operation_timeout == 600 + assert table.default_mutate_rows_operation_timeout == 600 + assert table.default_attempt_timeout == 20 + assert table.default_read_rows_attempt_timeout == 20 + assert table.default_mutate_rows_attempt_timeout == 60 + client.close() + + def test_table_ctor_invalid_timeout_values(self): + """bad timeout values should raise ValueError""" + client = self._make_client() + timeout_pairs = [ + ("default_operation_timeout", "default_attempt_timeout"), + ( + "default_read_rows_operation_timeout", + "default_read_rows_attempt_timeout", + ), + ( + "default_mutate_rows_operation_timeout", + "default_mutate_rows_attempt_timeout", + ), + ] + for operation_timeout, attempt_timeout in timeout_pairs: + with pytest.raises(ValueError) as e: + Table(client, "", 
"", **{attempt_timeout: -1}) + assert "attempt_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + Table(client, "", "", **{operation_timeout: -1}) + assert "operation_timeout must be greater than 0" in str(e.value) + client.close() @pytest.mark.parametrize( - "mutation_arg", + "fn_name,fn_args,is_stream,extra_retryables", [ - [mutations.SetCell("family", b"qualifier", b"value")], - [ - mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=1234567890 - ) - ], - [mutations.DeleteRangeFromColumn("family", b"qualifier")], - [mutations.DeleteAllFromFamily("family")], - [mutations.DeleteAllFromRow()], - [mutations.SetCell("family", b"qualifier", b"value")], - [ - mutations.DeleteRangeFromColumn("family", b"qualifier"), - mutations.DeleteAllFromRow(), - ], + ("read_rows_stream", (ReadRowsQuery(),), True, ()), + ("read_rows", (ReadRowsQuery(),), True, ()), + ("read_row", (b"row_key",), True, ()), + ("read_rows_sharded", ([ReadRowsQuery()],), True, ()), + ("row_exists", (b"row_key",), True, ()), + ("sample_row_keys", (), False, ()), + ("mutate_row", (b"row_key", [mock.Mock()]), False, ()), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + False, + (_MutateRowsIncomplete,), + ), ], ) - def test_bulk_mutate_rows(self, mutation_arg): - """Test mutations with no errors""" - expected_attempt_timeout = 19 - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.return_value = self._mock_response([None]) - bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg) - table.bulk_mutate_rows( - [bulk_mutation], attempt_timeout=expected_attempt_timeout - ) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args[1] - assert ( - kwargs["table_name"] - == "projects/project/instances/instance/tables/table" - ) - assert kwargs["entries"] == [bulk_mutation._to_pb()] - assert kwargs["timeout"] == expected_attempt_timeout - assert kwargs["retry"] is None - - def test_bulk_mutate_rows_multiple_entries(self): - """Test mutations with no errors""" - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.return_value = self._mock_response([None, None]) - mutation_list = [mutations.DeleteAllFromRow()] - entry_1 = mutations.RowMutationEntry(b"row_key_1", mutation_list) - entry_2 = mutations.RowMutationEntry(b"row_key_2", mutation_list) - table.bulk_mutate_rows([entry_1, entry_2]) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args[1] - assert ( - kwargs["table_name"] - == "projects/project/instances/instance/tables/table" + @pytest.mark.parametrize( + "input_retryables,expected_retryables", + [ + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), + ], + ) + def test_customizable_retryable_errors( + self, + input_retryables, + expected_retryables, + fn_name, + fn_args, + is_stream, + extra_retryables, + ): + """Test that 
retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer.""" + retry_fn = "retry_target" + if is_stream: + retry_fn += "_stream" + if CrossSync._Sync_Impl.is_async: + retry_fn = f"CrossSync.{retry_fn}" + else: + retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" + with mock.patch( + f"google.cloud.bigtable.data._sync.cross_sync.{retry_fn}" + ) as retry_fn_mock: + with self._make_client() as client: + table = client.get_table("instance-id", "table-id") + expected_predicate = lambda a: a in expected_retryables + retry_fn_mock.side_effect = RuntimeError("stop early") + with mock.patch( + "google.api_core.retry.if_exception_type" + ) as predicate_builder_mock: + predicate_builder_mock.return_value = expected_predicate + with pytest.raises(Exception): + test_fn = table.__getattribute__(fn_name) + test_fn(*fn_args, retryable_errors=input_retryables) + predicate_builder_mock.assert_called_once_with( + *expected_retryables, *extra_retryables ) - assert kwargs["entries"][0] == entry_1._to_pb() - assert kwargs["entries"][1] == entry_2._to_pb() + retry_call_args = retry_fn_mock.call_args_list[0].args + assert retry_call_args[1] is expected_predicate @pytest.mark.parametrize( - "exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_bulk_mutate_rows_idempotent_mutation_error_retryable(self, exception): - """Individual idempotent mutations should be retried if they fail with a retryable error""" - from google.cloud.bigtable.data.exceptions import ( - RetryExceptionGroup, - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = lambda *a, **k: self._mock_response( - [exception("mock")] - ) - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.DeleteAllFromRow() - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - table.bulk_mutate_rows([entry], operation_timeout=0.05) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert "non-idempotent" not in str(failed_exception) - assert isinstance(failed_exception, FailedMutationEntryError) - cause = failed_exception.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert isinstance(cause.exceptions[0], exception) - assert isinstance( - cause.exceptions[-1], core_exceptions.DeadlineExceeded - ) - - @pytest.mark.parametrize( - "exception", + "fn_name,fn_args,gapic_fn", [ - core_exceptions.OutOfRange, - core_exceptions.NotFound, - core_exceptions.FailedPrecondition, - core_exceptions.Aborted, + ("read_rows_stream", (ReadRowsQuery(),), "read_rows"), + ("read_rows", (ReadRowsQuery(),), "read_rows"), + ("read_row", (b"row_key",), "read_rows"), + ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), + ("row_exists", (b"row_key",), "read_rows"), + ("sample_row_keys", (), "sample_row_keys"), + ("mutate_row", (b"row_key", [mock.Mock()]), "mutate_row"), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + "mutate_rows", + ), + ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), + ( + "read_modify_write_row", + (b"row_key", mock.Mock()), + "read_modify_write_row", + ), ], ) - def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable(self, 
exception): - """Individual idempotent mutations should not be retried if they fail with a non-retryable error""" - from google.cloud.bigtable.data.exceptions import ( - FailedMutationEntryError, - MutationsExceptionGroup, - ) + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): + """check that all requests attach proper metadata headers""" + profile = "profile" if include_app_profile else None + with mock.patch.object(BigtableClient, gapic_fn, mock.Mock()) as gapic_mock: + gapic_mock.side_effect = RuntimeError("stop early") + with self._make_client() as client: + table = Table(client, "instance-id", "table-id", profile) + try: + test_fn = table.__getattribute__(fn_name) + maybe_stream = test_fn(*fn_args) + [i for i in maybe_stream] + except Exception: + pass + kwargs = gapic_mock.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == "x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = lambda *a, **k: self._mock_response( - [exception("mock")] - ) - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.DeleteAllFromRow() - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - table.bulk_mutate_rows([entry], operation_timeout=0.05) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert "non-idempotent" not in str(failed_exception) - assert isinstance(failed_exception, FailedMutationEntryError) - cause = failed_exception.__cause__ - assert isinstance(cause, exception) - @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_bulk_mutate_idempotent_retryable_request_errors(self, retryable_exception): - """Individual idempotent mutations should be retried if the request fails with a retryable error""" - from google.cloud.bigtable.data.exceptions import ( - RetryExceptionGroup, - FailedMutationEntryError, - MutationsExceptionGroup, - ) +class TestReadRows: + """ + Tests for table.read_rows and related methods. 
+ """ - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - table.bulk_mutate_rows([entry], operation_timeout=0.05) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" not in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert isinstance(cause.exceptions[0], retryable_exception) + @staticmethod + def _get_operation_class(): + return _ReadRowsOperation - @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_bulk_mutate_rows_non_idempotent_retryable_errors( - self, retryable_exception - ): - """Non-Idempotent mutations should never be retried""" - from google.cloud.bigtable.data.exceptions import ( - FailedMutationEntryError, - MutationsExceptionGroup, + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + + def _make_table(self, *args, **kwargs): + client_mock = mock.Mock() + client_mock._register_instance.side_effect = ( + lambda *args, **kwargs: CrossSync._Sync_Impl.yield_to_event_loop() + ) + client_mock._remove_instance_registration.side_effect = ( + lambda *args, **kwargs: CrossSync._Sync_Impl.yield_to_event_loop() + ) + kwargs["instance_id"] = kwargs.get( + "instance_id", args[0] if args else "instance" ) + kwargs["table_id"] = kwargs.get( + "table_id", args[1] if len(args) > 1 else "table" + ) + client_mock._gapic_client.table_path.return_value = kwargs["table_id"] + client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] + return TestTable._get_target_class()(client_mock, *args, **kwargs) - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = lambda *a, **k: self._mock_response( - [retryable_exception("mock")] - ) - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", -1 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is False - table.bulk_mutate_rows([entry], operation_timeout=0.2) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, retryable_exception) + def _make_stats(self): + from google.cloud.bigtable_v2.types import RequestStats + from google.cloud.bigtable_v2.types import FullReadStatsView + from google.cloud.bigtable_v2.types import ReadIterationStats - @pytest.mark.parametrize( - "non_retryable_exception", - [ - core_exceptions.OutOfRange, - core_exceptions.NotFound, - core_exceptions.FailedPrecondition, - RuntimeError, - ValueError, - ], - ) - def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception): 
- """If the request fails with a non-retryable error, mutations should not be retried""" - from google.cloud.bigtable.data.exceptions import ( - FailedMutationEntryError, - MutationsExceptionGroup, + return RequestStats( + full_read_stats_view=FullReadStatsView( + read_iteration_stats=ReadIterationStats( + rows_seen_count=1, + rows_returned_count=2, + cells_seen_count=3, + cells_returned_count=4, + ) + ) ) - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = non_retryable_exception("mock") - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - table.bulk_mutate_rows([entry], operation_timeout=0.2) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" not in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, non_retryable_exception) + @staticmethod + def _make_chunk(*args, **kwargs): + from google.cloud.bigtable_v2 import ReadRowsResponse - def test_bulk_mutate_error_index(self): - """Test partial failure, partial success. Errors should be associated with the correct index""" - from google.api_core.exceptions import ( - DeadlineExceeded, - ServiceUnavailable, - FailedPrecondition, - ) - from google.cloud.bigtable.data.exceptions import ( - RetryExceptionGroup, - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = [ - self._mock_response([None, ServiceUnavailable("mock"), None]), - self._mock_response([DeadlineExceeded("mock")]), - self._mock_response([FailedPrecondition("final")]), - ] - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entries = [ - mutations.RowMutationEntry( - f"row_key_{i}".encode(), [mutation] - ) - for i in range(3) - ] - assert mutation.is_idempotent() is True - table.bulk_mutate_rows(entries, operation_timeout=1000) - assert len(e.value.exceptions) == 1 - failed = e.value.exceptions[0] - assert isinstance(failed, FailedMutationEntryError) - assert failed.index == 1 - assert failed.entry == entries[1] - cause = failed.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert len(cause.exceptions) == 3 - assert isinstance(cause.exceptions[0], ServiceUnavailable) - assert isinstance(cause.exceptions[1], DeadlineExceeded) - assert isinstance(cause.exceptions[2], FailedPrecondition) - - def test_bulk_mutate_error_recovery(self): - """If an error occurs, then resolves, no exception should be raised""" - from google.api_core.exceptions import DeadlineExceeded + kwargs["row_key"] = kwargs.get("row_key", b"row_key") + kwargs["family_name"] = kwargs.get("family_name", "family_name") + kwargs["qualifier"] = kwargs.get("qualifier", b"qualifier") + kwargs["value"] = kwargs.get("value", b"value") + kwargs["commit_row"] = kwargs.get("commit_row", True) + return ReadRowsResponse.CellChunk(*args, **kwargs) - with self._make_client(project="project") as 
client: - table = client.get_table("instance", "table") - with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: - mock_gapic.side_effect = [ - self._mock_response([DeadlineExceeded("mock")]), - self._mock_response([None]), - ] - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entries = [ - mutations.RowMutationEntry(f"row_key_{i}".encode(), [mutation]) - for i in range(3) - ] - table.bulk_mutate_rows(entries, operation_timeout=1000) + @staticmethod + def _make_gapic_stream( + chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0 + ): + from google.cloud.bigtable_v2 import ReadRowsResponse + class mock_stream: + def __init__(self, chunk_list, sleep_time): + self.chunk_list = chunk_list + self.idx = -1 + self.sleep_time = sleep_time -class TestCheckAndMutateRow: - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) + def __aiter__(self): + return self - @pytest.mark.parametrize("gapic_result", [True, False]) - def test_check_and_mutate(self, gapic_result): - from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + def __anext__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + if sleep_time: + CrossSync._Sync_Impl.sleep(self.sleep_time) + chunk = self.chunk_list[self.idx] + if isinstance(chunk, Exception): + raise chunk + else: + return ReadRowsResponse(chunks=[chunk]) + raise StopIteration - app_profile = "app_profile_id" - with self._make_client() as client: - with client.get_table( - "instance", "table", app_profile_id=app_profile - ) as table: - with mock.patch.object( - client._gapic_client, "check_and_mutate_row" - ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=gapic_result - ) - row_key = b"row_key" - predicate = None - true_mutations = [mock.Mock()] - false_mutations = [mock.Mock(), mock.Mock()] - operation_timeout = 0.2 - found = table.check_and_mutate_row( - row_key, - predicate, - true_case_mutations=true_mutations, - false_case_mutations=false_mutations, - operation_timeout=operation_timeout, - ) - assert found == gapic_result - kwargs = mock_gapic.call_args[1] - assert kwargs["table_name"] == table.table_name - assert kwargs["row_key"] == row_key - assert kwargs["predicate_filter"] == predicate - assert kwargs["true_mutations"] == [ - m._to_pb() for m in true_mutations - ] - assert kwargs["false_mutations"] == [ - m._to_pb() for m in false_mutations - ] - assert kwargs["app_profile_id"] == app_profile - assert kwargs["timeout"] == operation_timeout - assert kwargs["retry"] is None + def cancel(self): + pass - def test_check_and_mutate_bad_timeout(self): - """Should raise error if operation_timeout < 0""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as e: - table.check_and_mutate_row( - b"row_key", - None, - true_case_mutations=[mock.Mock()], - false_case_mutations=[], - operation_timeout=-1, - ) - assert str(e.value) == "operation_timeout must be greater than 0" + return mock_stream(chunk_list, sleep_time) - def test_check_and_mutate_single_mutations(self): - """if single mutations are passed, they should be internally wrapped in a list""" - from google.cloud.bigtable.data.mutations import SetCell - from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + def execute_fn(self, table, *args, **kwargs): + return table.read_rows(*args, **kwargs) - with self._make_client() as client: - 
with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "check_and_mutate_row" - ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True - ) - true_mutation = SetCell("family", b"qualifier", b"value") - false_mutation = SetCell("family", b"qualifier", b"value") - table.check_and_mutate_row( - b"row_key", - None, - true_case_mutations=true_mutation, - false_case_mutations=false_mutation, - ) - kwargs = mock_gapic.call_args[1] - assert kwargs["true_mutations"] == [true_mutation._to_pb()] - assert kwargs["false_mutations"] == [false_mutation._to_pb()] + def test_read_rows(self): + query = ReadRowsQuery() + chunks = [ + self._make_chunk(row_key=b"test_1"), + self._make_chunk(row_key=b"test_2"), + ] + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + results = self.execute_fn(table, query, operation_timeout=3) + assert len(results) == 2 + assert results[0].row_key == b"test_1" + assert results[1].row_key == b"test_2" - def test_check_and_mutate_predicate_object(self): - """predicate filter should be passed to gapic request""" - from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + def test_read_rows_stream(self): + query = ReadRowsQuery() + chunks = [ + self._make_chunk(row_key=b"test_1"), + self._make_chunk(row_key=b"test_2"), + ] + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + gen = table.read_rows_stream(query, operation_timeout=3) + results = [row for row in gen] + assert len(results) == 2 + assert results[0].row_key == b"test_1" + assert results[1].row_key == b"test_2" - mock_predicate = mock.Mock() - predicate_pb = {"predicate": "dict"} - mock_predicate._to_pb.return_value = predicate_pb - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "check_and_mutate_row" - ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True - ) - table.check_and_mutate_row( - b"row_key", mock_predicate, false_case_mutations=[mock.Mock()] - ) - kwargs = mock_gapic.call_args[1] - assert kwargs["predicate_filter"] == predicate_pb - assert mock_predicate._to_pb.call_count == 1 - assert kwargs["retry"] is None + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_read_rows_query_matches_request(self, include_app_profile): + from google.cloud.bigtable.data import RowRange + from google.cloud.bigtable.data.row_filters import PassAllFilter - def test_check_and_mutate_mutations_parsing(self): - """mutations objects should be converted to protos""" - from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - from google.cloud.bigtable.data.mutations import DeleteAllFromRow + app_profile_id = "app_profile_id" if include_app_profile else None + with self._make_table(app_profile_id=app_profile_id) as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream([]) + row_keys = [b"test_1", "test_2"] + row_ranges = RowRange("1start", "2end") + filter_ = PassAllFilter(True) + limit = 99 + query = ReadRowsQuery( + row_keys=row_keys, + row_ranges=row_ranges, + row_filter=filter_, + limit=limit, + ) + results = table.read_rows(query, 
operation_timeout=3) + assert len(results) == 0 + call_request = read_rows.call_args_list[0][0][0] + query_pb = query._to_pb(table) + assert call_request == query_pb - mutations = [mock.Mock() for _ in range(5)] - for idx, mutation in enumerate(mutations): - mutation._to_pb.return_value = f"fake {idx}" - mutations.append(DeleteAllFromRow()) - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "check_and_mutate_row" - ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True - ) - table.check_and_mutate_row( - b"row_key", - None, - true_case_mutations=mutations[0:2], - false_case_mutations=mutations[2:], - ) - kwargs = mock_gapic.call_args[1] - assert kwargs["true_mutations"] == ["fake 0", "fake 1"] - assert kwargs["false_mutations"] == [ - "fake 2", - "fake 3", - "fake 4", - DeleteAllFromRow()._to_pb(), - ] - assert all( - (mutation._to_pb.call_count == 1 for mutation in mutations[:5]) - ) - - -class TestMutateRow: - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) + @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) + def test_read_rows_timeout(self, operation_timeout): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + query = ReadRowsQuery() + chunks = [self._make_chunk(row_key=b"test_1")] + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=0.15 + ) + try: + table.read_rows(query, operation_timeout=operation_timeout) + except core_exceptions.DeadlineExceeded as e: + assert ( + e.message + == f"operation_timeout of {operation_timeout:0.1f}s exceeded" + ) @pytest.mark.parametrize( - "mutation_arg", - [ - mutations.SetCell("family", b"qualifier", b"value"), - mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=1234567890 - ), - mutations.DeleteRangeFromColumn("family", b"qualifier"), - mutations.DeleteAllFromFamily("family"), - mutations.DeleteAllFromRow(), - [mutations.SetCell("family", b"qualifier", b"value")], - [ - mutations.DeleteRangeFromColumn("family", b"qualifier"), - mutations.DeleteAllFromRow(), - ], - ], + "per_request_t, operation_t, expected_num", + [(0.05, 0.08, 2), (0.05, 0.14, 3), (0.05, 0.24, 5)], ) - def test_mutate_row(self, mutation_arg): - """Test mutations with no errors""" - expected_attempt_timeout = 19 - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.return_value = None - table.mutate_row( - "row_key", - mutation_arg, - attempt_timeout=expected_attempt_timeout, - ) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args_list[0].kwargs - assert ( - kwargs["table_name"] - == "projects/project/instances/instance/tables/table" - ) - assert kwargs["row_key"] == b"row_key" - formatted_mutations = ( - [mutation._to_pb() for mutation in mutation_arg] - if isinstance(mutation_arg, list) - else [mutation_arg._to_pb()] - ) - assert kwargs["mutations"] == formatted_mutations - assert kwargs["timeout"] == expected_attempt_timeout - assert kwargs["retry"] is None + def test_read_rows_attempt_timeout(self, per_request_t, operation_t, expected_num): + """Ensures that the attempt_timeout is respected and that the number of + requests is as expected. 
- @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_mutate_row_retryable_errors(self, retryable_exception): - from google.api_core.exceptions import DeadlineExceeded + operation_timeout does not cancel the request, so we expect the number of + requests to be the ceiling of operation_timeout / attempt_timeout.""" from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(DeadlineExceeded) as e: - mutation = mutations.DeleteAllFromRow() - assert mutation.is_idempotent() is True - table.mutate_row("row_key", mutation, operation_timeout=0.01) - cause = e.value.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert isinstance(cause.exceptions[0], retryable_exception) + expected_last_timeout = operation_t - (expected_num - 1) * per_request_t + with mock.patch("random.uniform", side_effect=lambda a, b: 0): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=per_request_t + ) + query = ReadRowsQuery() + chunks = [core_exceptions.DeadlineExceeded("mock deadline")] + try: + table.read_rows( + query, + operation_timeout=operation_t, + attempt_timeout=per_request_t, + ) + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + if expected_num == 0: + assert retry_exc is None + else: + assert type(retry_exc) is RetryExceptionGroup + assert f"{expected_num} failed attempts" in str(retry_exc) + assert len(retry_exc.exceptions) == expected_num + for sub_exc in retry_exc.exceptions: + assert sub_exc.message == "mock deadline" + assert read_rows.call_count == expected_num + for _, call_kwargs in read_rows.call_args_list[:-1]: + assert call_kwargs["timeout"] == per_request_t + assert call_kwargs["retry"] is None + assert ( + abs( + read_rows.call_args_list[-1][1]["timeout"] + - expected_last_timeout + ) + < 0.05 + ) @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + "exc_type", + [ + core_exceptions.Aborted, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ], ) - def test_mutate_row_non_idempotent_retryable_errors(self, retryable_exception): - """Non-idempotent mutations should not be retried""" - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(retryable_exception): - mutation = mutations.SetCell( - "family", b"qualifier", b"value", -1 - ) - assert mutation.is_idempotent() is False - table.mutate_row("row_key", mutation, operation_timeout=0.2) + def test_read_rows_retryable_error(self, exc_type): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + try: + table.read_rows(query, operation_timeout=0.1) + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + 
root_cause = retry_exc.exceptions[0] + assert type(root_cause) is exc_type + assert root_cause == expected_error @pytest.mark.parametrize( - "non_retryable_exception", + "exc_type", [ - core_exceptions.OutOfRange, + core_exceptions.Cancelled, + core_exceptions.PreconditionFailed, core_exceptions.NotFound, - core_exceptions.FailedPrecondition, - RuntimeError, - ValueError, - core_exceptions.Aborted, + core_exceptions.PermissionDenied, + core_exceptions.Conflict, + core_exceptions.InternalServerError, + core_exceptions.TooManyRequests, + core_exceptions.ResourceExhausted, + InvalidChunk, ], ) - def test_mutate_row_non_retryable_errors(self, non_retryable_exception): - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = non_retryable_exception("mock") - with pytest.raises(non_retryable_exception): - mutation = mutations.SetCell( - "family", - b"qualifier", - b"value", - timestamp_micros=1234567890, - ) - assert mutation.is_idempotent() is True - table.mutate_row("row_key", mutation, operation_timeout=0.2) + def test_read_rows_non_retryable_error(self, exc_type): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + try: + table.read_rows(query, operation_timeout=0.1) + except exc_type as e: + assert e == expected_error - @pytest.mark.parametrize("include_app_profile", [True, False]) - def test_mutate_row_metadata(self, include_app_profile): - """request should attach metadata headers""" - profile = "profile" if include_app_profile else None - with self._make_client() as client: - with client.get_table("i", "t", app_profile_id=profile) as table: - with mock.patch.object( - client._gapic_client, "mutate_row", AsyncMock() - ) as read_rows: - table.mutate_row("rk", mock.Mock()) - kwargs = read_rows.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata + def test_read_rows_revise_request(self): + """Ensure that _revise_request is called between retries""" + from google.cloud.bigtable.data.exceptions import InvalidChunk + from google.cloud.bigtable_v2.types import RowSet - @pytest.mark.parametrize("mutations", [[], None]) - def test_mutate_row_no_mutations(self, mutations): + return_val = RowSet() + with mock.patch.object( + self._get_operation_class(), "_revise_request_rowset" + ) as revise_rowset: + revise_rowset.return_value = return_val + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + row_keys = [b"test_1", b"test_2", b"test_3"] + query = ReadRowsQuery(row_keys=row_keys) + chunks = [ + self._make_chunk(row_key=b"test_1"), + core_exceptions.Aborted("mock retryable error"), + ] + try: + table.read_rows(query) + except InvalidChunk: + revise_rowset.assert_called() + first_call_kwargs = revise_rowset.call_args_list[0].kwargs + assert 
first_call_kwargs["row_set"] == query._to_pb(table).rows + assert first_call_kwargs["last_seen_row_key"] == b"test_1" + revised_call = read_rows.call_args_list[1].args[0] + assert revised_call.rows == return_val + + def test_read_rows_default_timeouts(self): + """Ensure that the default timeouts are set on the read rows operation when not overridden""" + operation_timeout = 8 + attempt_timeout = 4 + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: + mock_op.side_effect = RuntimeError("mock error") + with self._make_table( + default_read_rows_operation_timeout=operation_timeout, + default_read_rows_attempt_timeout=attempt_timeout, + ) as table: + try: + table.read_rows(ReadRowsQuery()) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["attempt_timeout"] == attempt_timeout + + def test_read_rows_default_timeout_override(self): + """When timeouts are passed, they overwrite default values""" + operation_timeout = 8 + attempt_timeout = 4 + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: + mock_op.side_effect = RuntimeError("mock error") + with self._make_table( + default_operation_timeout=99, default_attempt_timeout=97 + ) as table: + try: + table.read_rows( + ReadRowsQuery(), + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + ) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["attempt_timeout"] == attempt_timeout + + def test_read_row(self): + """Test reading a single row""" with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as e: - table.mutate_row("key", mutations=mutations) - assert e.value.args[0] == "No mutations provided" + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + expected_result = object() + read_rows.side_effect = lambda *args, **kwargs: [expected_result] + expected_op_timeout = 8 + expected_req_timeout = 4 + row = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert row == expected_result + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert len(args) == 1 + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + def test_read_row_w_filter(self): + """Test reading a single row with an added filter""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + expected_result = object() + read_rows.side_effect = lambda *args, **kwargs: [expected_result] + expected_op_timeout = 8 + expected_req_timeout = 4 + mock_filter = mock.Mock() + expected_filter = {"filter": "mock filter"} + mock_filter._to_dict.return_value = expected_filter + row = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + row_filter=expected_filter, + ) + assert row == expected_result + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == 
expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert len(args) == 1 + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + assert query.filter == expected_filter -class TestReadModifyWriteRow: - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) + def test_read_row_no_response(self): + """should return None if row does not exist""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = lambda *args, **kwargs: [] + expected_op_timeout = 8 + expected_req_timeout = 4 + result = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert result is None + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 @pytest.mark.parametrize( - "call_rules,expected_rules", - [ - ( - AppendValueRule("f", "c", b"1"), - [AppendValueRule("f", "c", b"1")._to_pb()], - ), - ( - [AppendValueRule("f", "c", b"1")], - [AppendValueRule("f", "c", b"1")._to_pb()], - ), - (IncrementRule("f", "c", 1), [IncrementRule("f", "c", 1)._to_pb()]), - ( - [AppendValueRule("f", "c", b"1"), IncrementRule("f", "c", 1)], - [ - AppendValueRule("f", "c", b"1")._to_pb(), - IncrementRule("f", "c", 1)._to_pb(), - ], - ), - ], + "return_value,expected_result", + [([], False), ([object()], True), ([object(), object()], True)], ) - def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): - """Test that the gapic call is called with given rules""" + def test_row_exists(self, return_value, expected_result): + """Test checking for row existence""" with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - table.read_modify_write_row("key", call_rules) - assert mock_gapic.call_count == 1 - found_kwargs = mock_gapic.call_args_list[0][1] - assert found_kwargs["rules"] == expected_rules - assert found_kwargs["retry"] is None + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = lambda *args, **kwargs: return_value + expected_op_timeout = 1 + expected_req_timeout = 2 + result = table.row_exists( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert expected_result == result + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert isinstance(args[0], ReadRowsQuery) + expected_filter = { + "chain": { + "filters": [ + {"cells_per_row_limit_filter": 1}, + {"strip_value_transformer": True}, + ] + } + } + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + assert query.filter._to_dict() == expected_filter - @pytest.mark.parametrize("rules", [[], None]) - def 
test_read_modify_write_no_rules(self, rules): + +class TestReadRowsSharded: + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + + def test_read_rows_sharded_empty_query(self): with self._make_client() as client: with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as e: - table.read_modify_write_row("key", rules=rules) - assert e.value.args[0] == "rules must contain at least one item" + with pytest.raises(ValueError) as exc: + table.read_rows_sharded([]) + assert "empty sharded_query" in str(exc.value) - def test_read_modify_write_call_defaults(self): - instance = "instance1" - table_id = "table1" - project = "project1" - row_key = "row_key1" - with self._make_client(project=project) as client: - with client.get_table(instance, table_id) as table: + def test_read_rows_sharded_multiple_queries(self): + """Test with multiple queries. Should return results from both""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - table.read_modify_write_row(row_key, mock.Mock()) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args_list[0][1] - assert ( - kwargs["table_name"] - == f"projects/{project}/instances/{instance}/tables/{table_id}" - ) - assert kwargs["app_profile_id"] is None - assert kwargs["row_key"] == row_key.encode() - assert kwargs["timeout"] > 1 - - def test_read_modify_write_call_overrides(self): - row_key = b"row_key1" - expected_timeout = 12345 - profile_id = "profile1" - with self._make_client() as client: - with client.get_table( - "instance", "table_id", app_profile_id=profile_id - ) as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - table.read_modify_write_row( - row_key, mock.Mock(), operation_timeout=expected_timeout + table.client._gapic_client, "read_rows" + ) as read_rows: + read_rows.side_effect = ( + lambda *args, **kwargs: TestReadRows._make_gapic_stream( + [ + TestReadRows._make_chunk(row_key=k) + for k in args[0].rows.row_keys + ] + ) ) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args_list[0][1] - assert kwargs["app_profile_id"] is profile_id - assert kwargs["row_key"] == row_key - assert kwargs["timeout"] == expected_timeout + query_1 = ReadRowsQuery(b"test_1") + query_2 = ReadRowsQuery(b"test_2") + result = table.read_rows_sharded([query_1, query_2]) + assert len(result) == 2 + assert result[0].row_key == b"test_1" + assert result[1].row_key == b"test_2" - def test_read_modify_write_string_key(self): - row_key = "string_row_key1" + @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24]) + def test_read_rows_sharded_multiple_queries_calls(self, n_queries): + """Each query should trigger a separate read_rows call""" with self._make_client() as client: - with client.get_table("instance", "table_id") as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - table.read_modify_write_row(row_key, mock.Mock()) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args_list[0][1] - assert kwargs["row_key"] == row_key.encode() + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + query_list = [ReadRowsQuery() for _ in range(n_queries)] + table.read_rows_sharded(query_list) + assert read_rows.call_count == n_queries - def 
test_read_modify_write_row_building(self): - """results from gapic call should be used to construct row""" - from google.cloud.bigtable.data.row import Row - from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse - from google.cloud.bigtable_v2.types import Row as RowPB + def test_read_rows_sharded_errors(self): + """Errors should be exposed as ShardedReadRowsExceptionGroups""" + from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedQueryShardError - mock_response = ReadModifyWriteRowResponse(row=RowPB()) with self._make_client() as client: - with client.get_table("instance", "table_id") as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - with mock.patch.object(Row, "_from_pb") as constructor_mock: - mock_gapic.return_value = mock_response - table.read_modify_write_row("key", mock.Mock()) - assert constructor_mock.call_count == 1 - constructor_mock.assert_called_once_with(mock_response.row) - + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = RuntimeError("mock error") + query_1 = ReadRowsQuery(b"test_1") + query_2 = ReadRowsQuery(b"test_2") + with pytest.raises(ShardedReadRowsExceptionGroup) as exc: + table.read_rows_sharded([query_1, query_2]) + exc_group = exc.value + assert isinstance(exc_group, ShardedReadRowsExceptionGroup) + assert len(exc.value.exceptions) == 2 + assert isinstance(exc.value.exceptions[0], FailedQueryShardError) + assert isinstance(exc.value.exceptions[0].__cause__, RuntimeError) + assert exc.value.exceptions[0].index == 0 + assert exc.value.exceptions[0].query == query_1 + assert isinstance(exc.value.exceptions[1], FailedQueryShardError) + assert isinstance(exc.value.exceptions[1].__cause__, RuntimeError) + assert exc.value.exceptions[1].index == 1 + assert exc.value.exceptions[1].query == query_2 -class TestReadRows: - """ - Tests for table.read_rows and related methods. 
- """ + def test_read_rows_sharded_concurrent(self): + """Ensure sharded requests are concurrent""" + import time - @staticmethod - def _get_operation_class(): - return _ReadRowsOperation + def mock_call(*args, **kwargs): + asyncio.sleep(0.1) + return [mock.Mock()] - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = mock_call + queries = [ReadRowsQuery() for _ in range(10)] + start_time = time.monotonic() + result = table.read_rows_sharded(queries) + call_time = time.monotonic() - start_time + assert read_rows.call_count == 10 + assert len(result) == 10 + assert call_time < 0.2 - def _make_table(self, *args, **kwargs): - client_mock = mock.Mock() - client_mock._register_instance.side_effect = ( - lambda *args, **kwargs: CrossSync._Sync_Impl.yield_to_event_loop() - ) - client_mock._remove_instance_registration.side_effect = ( - lambda *args, **kwargs: CrossSync._Sync_Impl.yield_to_event_loop() - ) - kwargs["instance_id"] = kwargs.get( - "instance_id", args[0] if args else "instance" - ) - kwargs["table_id"] = kwargs.get( - "table_id", args[1] if len(args) > 1 else "table" - ) - client_mock._gapic_client.table_path.return_value = kwargs["table_id"] - client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] - return TestTable._get_target_class()(client_mock, *args, **kwargs) + def test_read_rows_sharded_concurrency_limit(self): + """Only 10 queries should be processed concurrently. Others should be queued - def _make_stats(self): - from google.cloud.bigtable_v2.types import RequestStats - from google.cloud.bigtable_v2.types import FullReadStatsView - from google.cloud.bigtable_v2.types import ReadIterationStats + Should start a new query as soon as previous finishes""" + from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT - return RequestStats( - full_read_stats_view=FullReadStatsView( - read_iteration_stats=ReadIterationStats( - rows_seen_count=1, - rows_returned_count=2, - cells_seen_count=3, - cells_returned_count=4, - ) - ) - ) + assert _CONCURRENCY_LIMIT == 10 + num_queries = 15 + increment_time = 0.05 + max_time = increment_time * (_CONCURRENCY_LIMIT - 1) + rpc_times = [min(i * increment_time, max_time) for i in range(num_queries)] - @staticmethod - def _make_chunk(*args, **kwargs): - from google.cloud.bigtable_v2 import ReadRowsResponse + def mock_call(*args, **kwargs): + next_sleep = rpc_times.pop(0) + asyncio.sleep(next_sleep) + return [mock.Mock()] - kwargs["row_key"] = kwargs.get("row_key", b"row_key") - kwargs["family_name"] = kwargs.get("family_name", "family_name") - kwargs["qualifier"] = kwargs.get("qualifier", b"qualifier") - kwargs["value"] = kwargs.get("value", b"value") - kwargs["commit_row"] = kwargs.get("commit_row", True) - return ReadRowsResponse.CellChunk(*args, **kwargs) + starting_timeout = 10 + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = mock_call + queries = [ReadRowsQuery() for _ in range(num_queries)] + table.read_rows_sharded(queries, operation_timeout=starting_timeout) + assert read_rows.call_count == num_queries + rpc_start_list = [ + starting_timeout - kwargs["operation_timeout"] + for _, kwargs in read_rows.call_args_list + ] + eps = 0.01 + assert all( + 
(rpc_start_list[i] < eps for i in range(_CONCURRENCY_LIMIT)) + ) + for i in range(num_queries - _CONCURRENCY_LIMIT): + idx = i + _CONCURRENCY_LIMIT + assert rpc_start_list[idx] - i * increment_time < eps - @staticmethod - def _make_gapic_stream( - chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0 - ): - from google.cloud.bigtable_v2 import ReadRowsResponse + def test_read_rows_sharded_expirary(self): + """If the operation times out before all shards complete, should raise + a ShardedReadRowsExceptionGroup""" + from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT + from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup + from google.api_core.exceptions import DeadlineExceeded - class mock_stream: - def __init__(self, chunk_list, sleep_time): - self.chunk_list = chunk_list - self.idx = -1 - self.sleep_time = sleep_time + operation_timeout = 0.1 + num_queries = 15 + sleeps = [0] * _CONCURRENCY_LIMIT + [DeadlineExceeded("times up")] * ( + num_queries - _CONCURRENCY_LIMIT + ) - def __iter__(self): - return self + def mock_call(*args, **kwargs): + next_item = sleeps.pop(0) + if isinstance(next_item, Exception): + raise next_item + else: + asyncio.sleep(next_item) + return [mock.Mock()] - def __next__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - if sleep_time: - CrossSync._Sync_Impl.sleep(self.sleep_time) - chunk = self.chunk_list[self.idx] - if isinstance(chunk, Exception): - raise chunk - else: - return ReadRowsResponse(chunks=[chunk]) - raise StopIteration + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = mock_call + queries = [ReadRowsQuery() for _ in range(num_queries)] + with pytest.raises(ShardedReadRowsExceptionGroup) as exc: + table.read_rows_sharded( + queries, operation_timeout=operation_timeout + ) + assert isinstance(exc.value, ShardedReadRowsExceptionGroup) + assert len(exc.value.exceptions) == num_queries - _CONCURRENCY_LIMIT + assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT - def cancel(self): - pass + def test_read_rows_sharded_negative_batch_timeout(self): + """try to run with batch that starts after operation timeout - return mock_stream(chunk_list, sleep_time) + They should raise DeadlineExceeded errors""" + from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup + from google.api_core.exceptions import DeadlineExceeded - def execute_fn(self, table, *args, **kwargs): - return table.read_rows(*args, **kwargs) + def mock_call(*args, **kwargs): + CrossSync._Sync_Impl.sleep(0.05) + return [mock.Mock()] - def test_read_rows(self): - query = ReadRowsQuery() - chunks = [ - self._make_chunk(row_key=b"test_1"), - self._make_chunk(row_key=b"test_2"), - ] - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks - ) - results = self.execute_fn(table, query, operation_timeout=3) - assert len(results) == 2 - assert results[0].row_key == b"test_1" - assert results[1].row_key == b"test_2" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = mock_call + queries = [ReadRowsQuery() for _ in range(15)] + with pytest.raises(ShardedReadRowsExceptionGroup) as exc: + table.read_rows_sharded(queries, operation_timeout=0.01) + 
assert isinstance(exc.value, ShardedReadRowsExceptionGroup) + assert len(exc.value.exceptions) == 5 + assert all( + ( + isinstance(e.__cause__, DeadlineExceeded) + for e in exc.value.exceptions + ) + ) - def test_read_rows_stream(self): - query = ReadRowsQuery() - chunks = [ - self._make_chunk(row_key=b"test_1"), - self._make_chunk(row_key=b"test_2"), - ] - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks - ) - gen = table.read_rows_stream(query, operation_timeout=3) - results = [row for row in gen] - assert len(results) == 2 - assert results[0].row_key == b"test_1" - assert results[1].row_key == b"test_2" - @pytest.mark.parametrize("include_app_profile", [True, False]) - def test_read_rows_query_matches_request(self, include_app_profile): - from google.cloud.bigtable.data import RowRange - from google.cloud.bigtable.data.row_filters import PassAllFilter +class TestSampleRowKeys: + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) - app_profile_id = "app_profile_id" if include_app_profile else None - with self._make_table(app_profile_id=app_profile_id) as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream([]) - row_keys = [b"test_1", "test_2"] - row_ranges = RowRange("1start", "2end") - filter_ = PassAllFilter(True) - limit = 99 - query = ReadRowsQuery( - row_keys=row_keys, - row_ranges=row_ranges, - row_filter=filter_, - limit=limit, - ) - results = table.read_rows(query, operation_timeout=3) - assert len(results) == 0 - call_request = read_rows.call_args_list[0][0][0] - query_pb = query._to_pb(table) - assert call_request == query_pb + def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): + from google.cloud.bigtable_v2.types import SampleRowKeysResponse - @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) - def test_read_rows_timeout(self, operation_timeout): - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - query = ReadRowsQuery() - chunks = [self._make_chunk(row_key=b"test_1")] - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks, sleep_time=0.15 - ) - try: - table.read_rows(query, operation_timeout=operation_timeout) - except core_exceptions.DeadlineExceeded as e: - assert ( - e.message - == f"operation_timeout of {operation_timeout:0.1f}s exceeded" - ) + for value in sample_list: + yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) - @pytest.mark.parametrize( - "per_request_t, operation_t, expected_num", - [(0.05, 0.08, 2), (0.05, 0.14, 3), (0.05, 0.24, 5)], - ) - def test_read_rows_attempt_timeout(self, per_request_t, operation_t, expected_num): - """Ensures that the attempt_timeout is respected and that the number of - requests is as expected. 
+ def test_sample_row_keys(self): + """Test that method returns the expected key samples""" + samples = [(b"test_1", 0), (b"test_2", 100), (b"test_3", 200)] + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", mock.Mock() + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream(samples) + result = table.sample_row_keys() + assert len(result) == 3 + assert all((isinstance(r, tuple) for r in result)) + assert all((isinstance(r[0], bytes) for r in result)) + assert all((isinstance(r[1], int) for r in result)) + assert result[0] == samples[0] + assert result[1] == samples[1] + assert result[2] == samples[2] - operation_timeout does not cancel the request, so we expect the number of - requests to be the ceiling of operation_timeout / attempt_timeout.""" - from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + def test_sample_row_keys_bad_timeout(self): + """should raise error if timeout is negative""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.sample_row_keys(operation_timeout=-1) + assert "operation_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + table.sample_row_keys(attempt_timeout=-1) + assert "attempt_timeout must be greater than 0" in str(e.value) - expected_last_timeout = operation_t - (expected_num - 1) * per_request_t - with mock.patch("random.uniform", side_effect=lambda a, b: 0): - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks, sleep_time=per_request_t - ) - query = ReadRowsQuery() - chunks = [core_exceptions.DeadlineExceeded("mock deadline")] - try: - table.read_rows( - query, - operation_timeout=operation_t, - attempt_timeout=per_request_t, - ) - except core_exceptions.DeadlineExceeded as e: - retry_exc = e.__cause__ - if expected_num == 0: - assert retry_exc is None - else: - assert type(retry_exc) is RetryExceptionGroup - assert f"{expected_num} failed attempts" in str(retry_exc) - assert len(retry_exc.exceptions) == expected_num - for sub_exc in retry_exc.exceptions: - assert sub_exc.message == "mock deadline" - assert read_rows.call_count == expected_num - for _, call_kwargs in read_rows.call_args_list[:-1]: - assert call_kwargs["timeout"] == per_request_t - assert call_kwargs["retry"] is None - assert ( - abs( - read_rows.call_args_list[-1][1]["timeout"] - - expected_last_timeout - ) - < 0.05 - ) + def test_sample_row_keys_default_timeout(self): + """Should fallback to using table default operation_timeout""" + expected_timeout = 99 + with self._make_client() as client: + with client.get_table( + "i", + "t", + default_operation_timeout=expected_timeout, + default_attempt_timeout=expected_timeout, + ) as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", mock.Mock() + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream([]) + result = table.sample_row_keys() + _, kwargs = sample_row_keys.call_args + assert abs(kwargs["timeout"] - expected_timeout) < 0.1 + assert result == [] + assert kwargs["retry"] is None + + def test_sample_row_keys_gapic_params(self): + """make sure arguments are propagated to gapic call as expected""" + expected_timeout = 10 + expected_profile = "test1" + instance = "instance_name" + 
table_id = "my_table" + with self._make_client() as client: + with client.get_table( + instance, table_id, app_profile_id=expected_profile + ) as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", mock.Mock() + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream([]) + table.sample_row_keys(attempt_timeout=expected_timeout) + args, kwargs = sample_row_keys.call_args + assert len(args) == 0 + assert len(kwargs) == 5 + assert kwargs["timeout"] == expected_timeout + assert kwargs["app_profile_id"] == expected_profile + assert kwargs["table_name"] == table.table_name + assert kwargs["metadata"] is not None + assert kwargs["retry"] is None @pytest.mark.parametrize( - "exc_type", - [ - core_exceptions.Aborted, - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - ], + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], ) - def test_read_rows_retryable_error(self, exc_type): - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - [expected_error] - ) - query = ReadRowsQuery() - expected_error = exc_type("mock error") - try: - table.read_rows(query, operation_timeout=0.1) - except core_exceptions.DeadlineExceeded as e: - retry_exc = e.__cause__ - root_cause = retry_exc.exceptions[0] - assert type(root_cause) is exc_type - assert root_cause == expected_error + def test_sample_row_keys_retryable_errors(self, retryable_exception): + """retryable errors should be retried until timeout""" + from google.api_core.exceptions import DeadlineExceeded + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", mock.Mock() + ) as sample_row_keys: + sample_row_keys.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + table.sample_row_keys(operation_timeout=0.05) + cause = e.value.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert len(cause.exceptions) > 0 + assert isinstance(cause.exceptions[0], retryable_exception) @pytest.mark.parametrize( - "exc_type", + "non_retryable_exception", [ - core_exceptions.Cancelled, - core_exceptions.PreconditionFailed, + core_exceptions.OutOfRange, core_exceptions.NotFound, - core_exceptions.PermissionDenied, - core_exceptions.Conflict, - core_exceptions.InternalServerError, - core_exceptions.TooManyRequests, - core_exceptions.ResourceExhausted, - InvalidChunk, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + core_exceptions.Aborted, ], ) - def test_read_rows_non_retryable_error(self, exc_type): - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - [expected_error] - ) - query = ReadRowsQuery() - expected_error = exc_type("mock error") - try: - table.read_rows(query, operation_timeout=0.1) - except exc_type as e: - assert e == expected_error + def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): + """non-retryable errors should cause a raise""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "sample_row_keys", mock.Mock() + ) as sample_row_keys: + 
sample_row_keys.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + table.sample_row_keys() - def test_read_rows_revise_request(self): - """Ensure that _revise_request is called between retries""" - from google.cloud.bigtable.data.exceptions import InvalidChunk - from google.cloud.bigtable_v2.types import RowSet - return_val = RowSet() - with mock.patch.object( - self._get_operation_class(), "_revise_request_rowset" - ) as revise_rowset: - revise_rowset.return_value = return_val - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks - ) - row_keys = [b"test_1", b"test_2", b"test_3"] - query = ReadRowsQuery(row_keys=row_keys) - chunks = [ - self._make_chunk(row_key=b"test_1"), - core_exceptions.Aborted("mock retryable error"), - ] - try: - table.read_rows(query) - except InvalidChunk: - revise_rowset.assert_called() - first_call_kwargs = revise_rowset.call_args_list[0].kwargs - assert first_call_kwargs["row_set"] == query._to_pb(table).rows - assert first_call_kwargs["last_seen_row_key"] == b"test_1" - revised_call = read_rows.call_args_list[1].args[0] - assert revised_call.rows == return_val - - def test_read_rows_default_timeouts(self): - """Ensure that the default timeouts are set on the read rows operation when not overridden""" - operation_timeout = 8 - attempt_timeout = 4 - with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: - mock_op.side_effect = RuntimeError("mock error") - with self._make_table( - default_read_rows_operation_timeout=operation_timeout, - default_read_rows_attempt_timeout=attempt_timeout, - ) as table: - try: - table.read_rows(ReadRowsQuery()) - except RuntimeError: - pass - kwargs = mock_op.call_args_list[0].kwargs - assert kwargs["operation_timeout"] == operation_timeout - assert kwargs["attempt_timeout"] == attempt_timeout +class TestMutateRow: + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) - def test_read_rows_default_timeout_override(self): - """When timeouts are passed, they overwrite default values""" - operation_timeout = 8 - attempt_timeout = 4 - with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: - mock_op.side_effect = RuntimeError("mock error") - with self._make_table( - default_operation_timeout=99, default_attempt_timeout=97 - ) as table: - try: - table.read_rows( - ReadRowsQuery(), - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, + @pytest.mark.parametrize( + "mutation_arg", + [ + mutations.SetCell("family", b"qualifier", b"value"), + mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=1234567890 + ), + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromFamily("family"), + mutations.DeleteAllFromRow(), + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromRow(), + ], + ], + ) + def test_mutate_row(self, mutation_arg): + """Test mutations with no errors""" + expected_attempt_timeout = 19 + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.return_value = None + table.mutate_row( + "row_key", + mutation_arg, + attempt_timeout=expected_attempt_timeout, ) - except RuntimeError: - 
pass - kwargs = mock_op.call_args_list[0].kwargs - assert kwargs["operation_timeout"] == operation_timeout - assert kwargs["attempt_timeout"] == attempt_timeout + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0].kwargs + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["row_key"] == b"row_key" + formatted_mutations = ( + [mutation._to_pb() for mutation in mutation_arg] + if isinstance(mutation_arg, list) + else [mutation_arg._to_pb()] + ) + assert kwargs["mutations"] == formatted_mutations + assert kwargs["timeout"] == expected_attempt_timeout + assert kwargs["retry"] is None - def test_read_row(self): - """Test reading a single row""" - with self._make_client() as client: - table = client.get_table("instance", "table") - row_key = b"test_1" - with mock.patch.object(table, "read_rows") as read_rows: - expected_result = object() - read_rows.side_effect = lambda *args, **kwargs: [expected_result] - expected_op_timeout = 8 - expected_req_timeout = 4 - row = table.read_row( - row_key, - operation_timeout=expected_op_timeout, - attempt_timeout=expected_req_timeout, - ) - assert row == expected_result - assert read_rows.call_count == 1 - args, kwargs = read_rows.call_args_list[0] - assert kwargs["operation_timeout"] == expected_op_timeout - assert kwargs["attempt_timeout"] == expected_req_timeout - assert len(args) == 1 - assert isinstance(args[0], ReadRowsQuery) - query = args[0] - assert query.row_keys == [row_key] - assert query.row_ranges == [] - assert query.limit == 1 + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_mutate_row_retryable_errors(self, retryable_exception): + from google.api_core.exceptions import DeadlineExceeded + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - def test_read_row_w_filter(self): - """Test reading a single row with an added filter""" - with self._make_client() as client: - table = client.get_table("instance", "table") - row_key = b"test_1" - with mock.patch.object(table, "read_rows") as read_rows: - expected_result = object() - read_rows.side_effect = lambda *args, **kwargs: [expected_result] - expected_op_timeout = 8 - expected_req_timeout = 4 - mock_filter = mock.Mock() - expected_filter = {"filter": "mock filter"} - mock_filter._to_dict.return_value = expected_filter - row = table.read_row( - row_key, - operation_timeout=expected_op_timeout, - attempt_timeout=expected_req_timeout, - row_filter=expected_filter, - ) - assert row == expected_result - assert read_rows.call_count == 1 - args, kwargs = read_rows.call_args_list[0] - assert kwargs["operation_timeout"] == expected_op_timeout - assert kwargs["attempt_timeout"] == expected_req_timeout - assert len(args) == 1 - assert isinstance(args[0], ReadRowsQuery) - query = args[0] - assert query.row_keys == [row_key] - assert query.row_ranges == [] - assert query.limit == 1 - assert query.filter == expected_filter + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + mutation = mutations.DeleteAllFromRow() + assert mutation.is_idempotent() is True + table.mutate_row("row_key", mutation, operation_timeout=0.01) + cause = e.value.__cause__ + assert isinstance(cause, 
RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) - def test_read_row_no_response(self): - """should return None if row does not exist""" - with self._make_client() as client: - table = client.get_table("instance", "table") - row_key = b"test_1" - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = lambda *args, **kwargs: [] - expected_op_timeout = 8 - expected_req_timeout = 4 - result = table.read_row( - row_key, - operation_timeout=expected_op_timeout, - attempt_timeout=expected_req_timeout, - ) - assert result is None - assert read_rows.call_count == 1 - args, kwargs = read_rows.call_args_list[0] - assert kwargs["operation_timeout"] == expected_op_timeout - assert kwargs["attempt_timeout"] == expected_req_timeout - assert isinstance(args[0], ReadRowsQuery) - query = args[0] - assert query.row_keys == [row_key] - assert query.row_ranges == [] - assert query.limit == 1 + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_mutate_row_non_idempotent_retryable_errors(self, retryable_exception): + """Non-idempotent mutations should not be retried""" + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(retryable_exception): + mutation = mutations.SetCell( + "family", b"qualifier", b"value", -1 + ) + assert mutation.is_idempotent() is False + table.mutate_row("row_key", mutation, operation_timeout=0.2) @pytest.mark.parametrize( - "return_value,expected_result", - [([], False), ([object()], True), ([object(), object()], True)], + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + core_exceptions.Aborted, + ], ) - def test_row_exists(self, return_value, expected_result): - """Test checking for row existence""" + def test_mutate_row_non_retryable_errors(self, non_retryable_exception): + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + mutation = mutations.SetCell( + "family", + b"qualifier", + b"value", + timestamp_micros=1234567890, + ) + assert mutation.is_idempotent() is True + table.mutate_row("row_key", mutation, operation_timeout=0.2) + + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_mutate_row_metadata(self, include_app_profile): + """request should attach metadata headers""" + profile = "profile" if include_app_profile else None with self._make_client() as client: - table = client.get_table("instance", "table") - row_key = b"test_1" - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = lambda *args, **kwargs: return_value - expected_op_timeout = 1 - expected_req_timeout = 2 - result = table.row_exists( - row_key, - operation_timeout=expected_op_timeout, - attempt_timeout=expected_req_timeout, - ) - assert expected_result == result - assert read_rows.call_count == 1 - args, kwargs = read_rows.call_args_list[0] - assert kwargs["operation_timeout"] == expected_op_timeout - assert kwargs["attempt_timeout"] == 
expected_req_timeout - assert isinstance(args[0], ReadRowsQuery) - expected_filter = { - "chain": { - "filters": [ - {"cells_per_row_limit_filter": 1}, - {"strip_value_transformer": True}, - ] - } - } - query = args[0] - assert query.row_keys == [row_key] - assert query.row_ranges == [] - assert query.limit == 1 - assert query.filter._to_dict() == expected_filter + with client.get_table("i", "t", app_profile_id=profile) as table: + with mock.patch.object( + client._gapic_client, "mutate_row", AsyncMock() + ) as read_rows: + table.mutate_row("rk", mock.Mock()) + kwargs = read_rows.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == "x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata + + @pytest.mark.parametrize("mutations", [[], None]) + def test_mutate_row_no_mutations(self, mutations): + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.mutate_row("key", mutations=mutations) + assert e.value.args[0] == "No mutations provided" -class TestReadRowsSharded: +class TestBulkMutateRows: def _make_client(self, *args, **kwargs): return TestBigtableDataClient._make_client(*args, **kwargs) - def test_read_rows_sharded_empty_query(self): - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as exc: - table.read_rows_sharded([]) - assert "empty sharded_query" in str(exc.value) + def _mock_response(self, response_list): + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc import status_pb2 - def test_read_rows_sharded_multiple_queries(self): - """Test with multiple queries. 
Should return results from both""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - table.client._gapic_client, "read_rows" - ) as read_rows: - read_rows.side_effect = ( - lambda *args, **kwargs: TestReadRows._make_gapic_stream( - [ - TestReadRows._make_chunk(row_key=k) - for k in args[0].rows.row_keys - ] - ) + statuses = [] + for response in response_list: + if isinstance(response, core_exceptions.GoogleAPICallError): + statuses.append( + status_pb2.Status( + message=str(response), code=response.grpc_status_code.value[0] ) - query_1 = ReadRowsQuery(b"test_1") - query_2 = ReadRowsQuery(b"test_2") - result = table.read_rows_sharded([query_1, query_2]) - assert len(result) == 2 - assert result[0].row_key == b"test_1" - assert result[1].row_key == b"test_2" + ) + else: + statuses.append(status_pb2.Status(code=0)) + entries = [ + MutateRowsResponse.Entry(index=i, status=statuses[i]) + for i in range(len(response_list)) + ] - @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24]) - def test_read_rows_sharded_multiple_queries_calls(self, n_queries): - """Each query should trigger a separate read_rows call""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - query_list = [ReadRowsQuery() for _ in range(n_queries)] - table.read_rows_sharded(query_list) - assert read_rows.call_count == n_queries + def generator(): + yield MutateRowsResponse(entries=entries) - def test_read_rows_sharded_errors(self): - """Errors should be exposed as ShardedReadRowsExceptionGroups""" - from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup - from google.cloud.bigtable.data.exceptions import FailedQueryShardError + return generator() - with self._make_client() as client: + @pytest.mark.parametrize( + "mutation_arg", + [ + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=1234567890 + ) + ], + [mutations.DeleteRangeFromColumn("family", b"qualifier")], + [mutations.DeleteAllFromFamily("family")], + [mutations.DeleteAllFromRow()], + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromRow(), + ], + ], + ) + def test_bulk_mutate_rows(self, mutation_arg): + """Test mutations with no errors""" + expected_attempt_timeout = 19 + with self._make_client(project="project") as client: with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = RuntimeError("mock error") - query_1 = ReadRowsQuery(b"test_1") - query_2 = ReadRowsQuery(b"test_2") - with pytest.raises(ShardedReadRowsExceptionGroup) as exc: - table.read_rows_sharded([query_1, query_2]) - exc_group = exc.value - assert isinstance(exc_group, ShardedReadRowsExceptionGroup) - assert len(exc.value.exceptions) == 2 - assert isinstance(exc.value.exceptions[0], FailedQueryShardError) - assert isinstance(exc.value.exceptions[0].__cause__, RuntimeError) - assert exc.value.exceptions[0].index == 0 - assert exc.value.exceptions[0].query == query_1 - assert isinstance(exc.value.exceptions[1], FailedQueryShardError) - assert isinstance(exc.value.exceptions[1].__cause__, RuntimeError) - assert exc.value.exceptions[1].index == 1 - assert exc.value.exceptions[1].query == query_2 - - def test_read_rows_sharded_concurrent(self): - """Ensure 
sharded requests are concurrent""" - import time - - def mock_call(*args, **kwargs): - asyncio.sleep(0.1) - return [mock.Mock()] + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.return_value = self._mock_response([None]) + bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg) + table.bulk_mutate_rows( + [bulk_mutation], attempt_timeout=expected_attempt_timeout + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"] == [bulk_mutation._to_pb()] + assert kwargs["timeout"] == expected_attempt_timeout + assert kwargs["retry"] is None - with self._make_client() as client: + def test_bulk_mutate_rows_multiple_entries(self): + """Test mutations with no errors""" + with self._make_client(project="project") as client: with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = mock_call - queries = [ReadRowsQuery() for _ in range(10)] - start_time = time.monotonic() - result = table.read_rows_sharded(queries) - call_time = time.monotonic() - start_time - assert read_rows.call_count == 10 - assert len(result) == 10 - assert call_time < 0.2 - - def test_read_rows_sharded_concurrency_limit(self): - """Only 10 queries should be processed concurrently. Others should be queued + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.return_value = self._mock_response([None, None]) + mutation_list = [mutations.DeleteAllFromRow()] + entry_1 = mutations.RowMutationEntry(b"row_key_1", mutation_list) + entry_2 = mutations.RowMutationEntry(b"row_key_2", mutation_list) + table.bulk_mutate_rows([entry_1, entry_2]) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"][0] == entry_1._to_pb() + assert kwargs["entries"][1] == entry_2._to_pb() - Should start a new query as soon as previous finishes""" - from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT + @pytest.mark.parametrize( + "exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_rows_idempotent_mutation_error_retryable(self, exception): + """Individual idempotent mutations should be retried if they fail with a retryable error""" + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) - assert _CONCURRENCY_LIMIT == 10 - num_queries = 15 - increment_time = 0.05 - max_time = increment_time * (_CONCURRENCY_LIMIT - 1) - rpc_times = [min(i * increment_time, max_time) for i in range(num_queries)] + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert 
isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], exception) + assert isinstance( + cause.exceptions[-1], core_exceptions.DeadlineExceeded + ) - def mock_call(*args, **kwargs): - next_sleep = rpc_times.pop(0) - asyncio.sleep(next_sleep) - return [mock.Mock()] + @pytest.mark.parametrize( + "exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + core_exceptions.Aborted, + ], + ) + def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable(self, exception): + """Individual idempotent mutations should not be retried if they fail with a non-retryable error""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) - starting_timeout = 10 - with self._make_client() as client: + with self._make_client(project="project") as client: with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = mock_call - queries = [ReadRowsQuery() for _ in range(num_queries)] - table.read_rows_sharded(queries, operation_timeout=starting_timeout) - assert read_rows.call_count == num_queries - rpc_start_list = [ - starting_timeout - kwargs["operation_timeout"] - for _, kwargs in read_rows.call_args_list - ] - eps = 0.01 - assert all( - (rpc_start_list[i] < eps for i in range(_CONCURRENCY_LIMIT)) + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] ) - for i in range(num_queries - _CONCURRENCY_LIMIT): - idx = i + _CONCURRENCY_LIMIT - assert rpc_start_list[idx] - i * increment_time < eps - - def test_read_rows_sharded_expirary(self): - """If the operation times out before all shards complete, should raise - a ShardedReadRowsExceptionGroup""" - from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT - from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup - from google.api_core.exceptions import DeadlineExceeded + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, exception) - operation_timeout = 0.1 - num_queries = 15 - sleeps = [0] * _CONCURRENCY_LIMIT + [DeadlineExceeded("times up")] * ( - num_queries - _CONCURRENCY_LIMIT + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_idempotent_retryable_request_errors(self, retryable_exception): + """Individual idempotent mutations should be retried if the request fails with a retryable error""" + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, ) - def mock_call(*args, **kwargs): - next_item = sleeps.pop(0) - if isinstance(next_item, Exception): - raise next_item - else: - asyncio.sleep(next_item) - return [mock.Mock()] - - with self._make_client() as client: + with 
self._make_client(project="project") as client: with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = mock_call - queries = [ReadRowsQuery() for _ in range(num_queries)] - with pytest.raises(ShardedReadRowsExceptionGroup) as exc: - table.read_rows_sharded( - queries, operation_timeout=operation_timeout + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 ) - assert isinstance(exc.value, ShardedReadRowsExceptionGroup) - assert len(exc.value.exceptions) == num_queries - _CONCURRENCY_LIMIT - assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT - - def test_read_rows_sharded_negative_batch_timeout(self): - """try to run with batch that starts after operation timeout - - They should raise DeadlineExceeded errors""" - from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup - from google.api_core.exceptions import DeadlineExceeded + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) - def mock_call(*args, **kwargs): - CrossSync._Sync_Impl.sleep(0.05) - return [mock.Mock()] + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_rows_non_idempotent_retryable_errors( + self, retryable_exception + ): + """Non-Idempotent mutations should never be retried""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) - with self._make_client() as client: + with self._make_client(project="project") as client: with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = mock_call - queries = [ReadRowsQuery() for _ in range(15)] - with pytest.raises(ShardedReadRowsExceptionGroup) as exc: - table.read_rows_sharded(queries, operation_timeout=0.01) - assert isinstance(exc.value, ShardedReadRowsExceptionGroup) - assert len(exc.value.exceptions) == 5 - assert all( - ( - isinstance(e.__cause__, DeadlineExceeded) - for e in exc.value.exceptions - ) + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [retryable_exception("mock")] ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", -1 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is False + table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, retryable_exception) + 
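Editorial aside (not part of the patch hunks above or below): the retry behavior these bulk_mutate_rows tests exercise hinges on whether a mutation is idempotent. A minimal sketch of that rule, using only the calls the tests themselves make — SetCell with an explicit timestamp is treated as idempotent (safe to retry), while timestamp_micros=-1 (server-assigned time) is not, so such an entry is never retried:

    # sketch only; mirrors the assertions in the tests above
    from google.cloud.bigtable.data.mutations import SetCell

    # explicit client-supplied timestamp -> idempotent, eligible for retries
    retried = SetCell("family", b"qualifier", b"value", timestamp_micros=1234567890)
    assert retried.is_idempotent() is True

    # -1 asks the server to assign the timestamp -> not idempotent, never retried
    not_retried = SetCell("family", b"qualifier", b"value", timestamp_micros=-1)
    assert not_retried.is_idempotent() is False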
@pytest.mark.parametrize( + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + ], + ) + def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception): + """If the request fails with a non-retryable error, mutations should not be retried""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) -class TestSampleRowKeys: - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): - from google.cloud.bigtable_v2.types import SampleRowKeysResponse + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, non_retryable_exception) - for value in sample_list: - yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) + def test_bulk_mutate_error_index(self): + """Test partial failure, partial success. Errors should be associated with the correct index""" + from google.api_core.exceptions import ( + DeadlineExceeded, + ServiceUnavailable, + FailedPrecondition, + ) + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) - def test_sample_row_keys(self): - """Test that method returns the expected key samples""" - samples = [(b"test_1", 0), (b"test_2", 100), (b"test_3", 200)] - with self._make_client() as client: + with self._make_client(project="project") as client: with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() - ) as sample_row_keys: - sample_row_keys.return_value = self._make_gapic_stream(samples) - result = table.sample_row_keys() - assert len(result) == 3 - assert all((isinstance(r, tuple) for r in result)) - assert all((isinstance(r[0], bytes) for r in result)) - assert all((isinstance(r[1], int) for r in result)) - assert result[0] == samples[0] - assert result[1] == samples[1] - assert result[2] == samples[2] + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = [ + self._mock_response([None, ServiceUnavailable("mock"), None]), + self._mock_response([DeadlineExceeded("mock")]), + self._mock_response([FailedPrecondition("final")]), + ] + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entries = [ + mutations.RowMutationEntry( + f"row_key_{i}".encode(), [mutation] + ) + for i in range(3) + ] + assert mutation.is_idempotent() is True + table.bulk_mutate_rows(entries, operation_timeout=1000) + assert len(e.value.exceptions) == 1 + failed = 
e.value.exceptions[0] + assert isinstance(failed, FailedMutationEntryError) + assert failed.index == 1 + assert failed.entry == entries[1] + cause = failed.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert len(cause.exceptions) == 3 + assert isinstance(cause.exceptions[0], ServiceUnavailable) + assert isinstance(cause.exceptions[1], DeadlineExceeded) + assert isinstance(cause.exceptions[2], FailedPrecondition) - def test_sample_row_keys_bad_timeout(self): - """should raise error if timeout is negative""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as e: - table.sample_row_keys(operation_timeout=-1) - assert "operation_timeout must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - table.sample_row_keys(attempt_timeout=-1) - assert "attempt_timeout must be greater than 0" in str(e.value) + def test_bulk_mutate_error_recovery(self): + """If an error occurs, then resolves, no exception should be raised""" + from google.api_core.exceptions import DeadlineExceeded - def test_sample_row_keys_default_timeout(self): - """Should fallback to using table default operation_timeout""" - expected_timeout = 99 - with self._make_client() as client: - with client.get_table( - "i", - "t", - default_operation_timeout=expected_timeout, - default_attempt_timeout=expected_timeout, - ) as table: - with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() - ) as sample_row_keys: - sample_row_keys.return_value = self._make_gapic_stream([]) - result = table.sample_row_keys() - _, kwargs = sample_row_keys.call_args - assert abs(kwargs["timeout"] - expected_timeout) < 0.1 - assert result == [] - assert kwargs["retry"] is None + with self._make_client(project="project") as client: + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.side_effect = [ + self._mock_response([DeadlineExceeded("mock")]), + self._mock_response([None]), + ] + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entries = [ + mutations.RowMutationEntry(f"row_key_{i}".encode(), [mutation]) + for i in range(3) + ] + table.bulk_mutate_rows(entries, operation_timeout=1000) - def test_sample_row_keys_gapic_params(self): - """make sure arguments are propagated to gapic call as expected""" - expected_timeout = 10 - expected_profile = "test1" - instance = "instance_name" - table_id = "my_table" + +class TestCheckAndMutateRow: + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) + + @pytest.mark.parametrize("gapic_result", [True, False]) + def test_check_and_mutate(self, gapic_result): + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + app_profile = "app_profile_id" with self._make_client() as client: with client.get_table( - instance, table_id, app_profile_id=expected_profile + "instance", "table", app_profile_id=app_profile ) as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() - ) as sample_row_keys: - sample_row_keys.return_value = self._make_gapic_stream([]) - table.sample_row_keys(attempt_timeout=expected_timeout) - args, kwargs = sample_row_keys.call_args - assert len(args) == 0 - assert len(kwargs) == 5 - assert kwargs["timeout"] == expected_timeout - assert kwargs["app_profile_id"] == expected_profile + client._gapic_client, "check_and_mutate_row" + ) as 
mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=gapic_result + ) + row_key = b"row_key" + predicate = None + true_mutations = [mock.Mock()] + false_mutations = [mock.Mock(), mock.Mock()] + operation_timeout = 0.2 + found = table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutations, + false_case_mutations=false_mutations, + operation_timeout=operation_timeout, + ) + assert found == gapic_result + kwargs = mock_gapic.call_args[1] assert kwargs["table_name"] == table.table_name - assert kwargs["metadata"] is not None + assert kwargs["row_key"] == row_key + assert kwargs["predicate_filter"] == predicate + assert kwargs["true_mutations"] == [ + m._to_pb() for m in true_mutations + ] + assert kwargs["false_mutations"] == [ + m._to_pb() for m in false_mutations + ] + assert kwargs["app_profile_id"] == app_profile + assert kwargs["timeout"] == operation_timeout assert kwargs["retry"] is None - @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_sample_row_keys_retryable_errors(self, retryable_exception): - """retryable errors should be retried until timeout""" - from google.api_core.exceptions import DeadlineExceeded - from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - + def test_check_and_mutate_bad_timeout(self): + """Should raise error if operation_timeout < 0""" with self._make_client() as client: with client.get_table("instance", "table") as table: - with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() - ) as sample_row_keys: - sample_row_keys.side_effect = retryable_exception("mock") - with pytest.raises(DeadlineExceeded) as e: - table.sample_row_keys(operation_timeout=0.05) - cause = e.value.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert len(cause.exceptions) > 0 - assert isinstance(cause.exceptions[0], retryable_exception) + with pytest.raises(ValueError) as e: + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=[mock.Mock()], + false_case_mutations=[], + operation_timeout=-1, + ) + assert str(e.value) == "operation_timeout must be greater than 0" + + def test_check_and_mutate_single_mutations(self): + """if single mutations are passed, they should be internally wrapped in a list""" + from google.cloud.bigtable.data.mutations import SetCell + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - @pytest.mark.parametrize( - "non_retryable_exception", - [ - core_exceptions.OutOfRange, - core_exceptions.NotFound, - core_exceptions.FailedPrecondition, - RuntimeError, - ValueError, - core_exceptions.Aborted, - ], - ) - def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): - """non-retryable errors should cause a raise""" with self._make_client() as client: with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() - ) as sample_row_keys: - sample_row_keys.side_effect = non_retryable_exception("mock") - with pytest.raises(non_retryable_exception): - table.sample_row_keys() - - -class TestTable: - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + true_mutation = SetCell("family", b"qualifier", b"value") + false_mutation = 
SetCell("family", b"qualifier", b"value") + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["true_mutations"] == [true_mutation._to_pb()] + assert kwargs["false_mutations"] == [false_mutation._to_pb()] - @staticmethod - def _get_target_class(): - return Table + def test_check_and_mutate_predicate_object(self): + """predicate filter should be passed to gapic request""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - def test_table_ctor(self): - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + mock_predicate = mock.Mock() + predicate_pb = {"predicate": "dict"} + mock_predicate._to_pb.return_value = predicate_pb + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + table.check_and_mutate_row( + b"row_key", mock_predicate, false_case_mutations=[mock.Mock()] + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["predicate_filter"] == predicate_pb + assert mock_predicate._to_pb.call_count == 1 + assert kwargs["retry"] is None - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - expected_operation_timeout = 123 - expected_attempt_timeout = 12 - expected_read_rows_operation_timeout = 1.5 - expected_read_rows_attempt_timeout = 0.5 - expected_mutate_rows_operation_timeout = 2.5 - expected_mutate_rows_attempt_timeout = 0.75 - client = self._make_client() - assert not client._active_instances - table = self._get_target_class()( - client, - expected_instance_id, - expected_table_id, - expected_app_profile_id, - default_operation_timeout=expected_operation_timeout, - default_attempt_timeout=expected_attempt_timeout, - default_read_rows_operation_timeout=expected_read_rows_operation_timeout, - default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout, - default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, - default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, - ) - CrossSync._Sync_Impl.yield_to_event_loop() - assert table.table_id == expected_table_id - assert table.instance_id == expected_instance_id - assert table.app_profile_id == expected_app_profile_id - assert table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id - ) - assert instance_key in client._active_instances - assert client._instance_owners[instance_key] == {id(table)} - assert table.default_operation_timeout == expected_operation_timeout - assert table.default_attempt_timeout == expected_attempt_timeout - assert ( - table.default_read_rows_operation_timeout - == expected_read_rows_operation_timeout - ) - assert ( - table.default_read_rows_attempt_timeout - == expected_read_rows_attempt_timeout - ) - assert ( - table.default_mutate_rows_operation_timeout - == expected_mutate_rows_operation_timeout - ) - assert ( - table.default_mutate_rows_attempt_timeout - == expected_mutate_rows_attempt_timeout - ) - table._register_instance_future - assert table._register_instance_future.done() - assert not table._register_instance_future.cancelled() - assert table._register_instance_future.exception() is None - client.close() + def test_check_and_mutate_mutations_parsing(self): 
+ """mutations objects should be converted to protos""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + from google.cloud.bigtable.data.mutations import DeleteAllFromRow - def test_table_ctor_defaults(self): - """should provide default timeout values and app_profile_id""" - expected_table_id = "table-id" - expected_instance_id = "instance-id" - client = self._make_client() - assert not client._active_instances - table = Table(client, expected_instance_id, expected_table_id) - CrossSync._Sync_Impl.yield_to_event_loop() - assert table.table_id == expected_table_id - assert table.instance_id == expected_instance_id - assert table.app_profile_id is None - assert table.client is client - assert table.default_operation_timeout == 60 - assert table.default_read_rows_operation_timeout == 600 - assert table.default_mutate_rows_operation_timeout == 600 - assert table.default_attempt_timeout == 20 - assert table.default_read_rows_attempt_timeout == 20 - assert table.default_mutate_rows_attempt_timeout == 60 - client.close() + mutations = [mock.Mock() for _ in range(5)] + for idx, mutation in enumerate(mutations): + mutation._to_pb.return_value = f"fake {idx}" + mutations.append(DeleteAllFromRow()) + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=mutations[0:2], + false_case_mutations=mutations[2:], + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["true_mutations"] == ["fake 0", "fake 1"] + assert kwargs["false_mutations"] == [ + "fake 2", + "fake 3", + "fake 4", + DeleteAllFromRow()._to_pb(), + ] + assert all( + (mutation._to_pb.call_count == 1 for mutation in mutations[:5]) + ) - def test_table_ctor_invalid_timeout_values(self): - """bad timeout values should raise ValueError""" - client = self._make_client() - timeout_pairs = [ - ("default_operation_timeout", "default_attempt_timeout"), - ( - "default_read_rows_operation_timeout", - "default_read_rows_attempt_timeout", - ), - ( - "default_mutate_rows_operation_timeout", - "default_mutate_rows_attempt_timeout", - ), - ] - for operation_timeout, attempt_timeout in timeout_pairs: - with pytest.raises(ValueError) as e: - Table(client, "", "", **{attempt_timeout: -1}) - assert "attempt_timeout must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - Table(client, "", "", **{operation_timeout: -1}) - assert "operation_timeout must be greater than 0" in str(e.value) - client.close() + +class TestReadModifyWriteRow: + def _make_client(self, *args, **kwargs): + return TestBigtableDataClient._make_client(*args, **kwargs) @pytest.mark.parametrize( - "fn_name,fn_args,is_stream,extra_retryables", - [ - ("read_rows_stream", (ReadRowsQuery(),), True, ()), - ("read_rows", (ReadRowsQuery(),), True, ()), - ("read_row", (b"row_key",), True, ()), - ("read_rows_sharded", ([ReadRowsQuery()],), True, ()), - ("row_exists", (b"row_key",), True, ()), - ("sample_row_keys", (), False, ()), - ("mutate_row", (b"row_key", [mock.Mock()]), False, ()), - ( - "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), - False, - (_MutateRowsIncomplete,), - ), - ], - ) - @pytest.mark.parametrize( - "input_retryables,expected_retryables", + "call_rules,expected_rules", [ ( - TABLE_DEFAULT.READ_ROWS, - [ - 
core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - core_exceptions.Aborted, - ], + AppendValueRule("f", "c", b"1"), + [AppendValueRule("f", "c", b"1")._to_pb()], ), ( - TABLE_DEFAULT.DEFAULT, - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + [AppendValueRule("f", "c", b"1")], + [AppendValueRule("f", "c", b"1")._to_pb()], ), + (IncrementRule("f", "c", 1), [IncrementRule("f", "c", 1)._to_pb()]), ( - TABLE_DEFAULT.MUTATE_ROWS, - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + [AppendValueRule("f", "c", b"1"), IncrementRule("f", "c", 1)], + [ + AppendValueRule("f", "c", b"1")._to_pb(), + IncrementRule("f", "c", 1)._to_pb(), + ], ), - ([], []), - ([4], [core_exceptions.DeadlineExceeded]), ], ) - def test_customizable_retryable_errors( - self, - input_retryables, - expected_retryables, - fn_name, - fn_args, - is_stream, - extra_retryables, - ): - """Test that retryable functions support user-configurable arguments, and that the configured retryables are passed - down to the gapic layer.""" - retry_fn = "retry_target" - if is_stream: - retry_fn += "_stream" - if CrossSync._Sync_Impl.is_async: - retry_fn = f"CrossSync.{retry_fn}" - else: - retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" - with mock.patch( - f"google.cloud.bigtable.data._sync.cross_sync.{retry_fn}" - ) as retry_fn_mock: - with self._make_client() as client: - table = client.get_table("instance-id", "table-id") - expected_predicate = lambda a: a in expected_retryables - retry_fn_mock.side_effect = RuntimeError("stop early") - with mock.patch( - "google.api_core.retry.if_exception_type" - ) as predicate_builder_mock: - predicate_builder_mock.return_value = expected_predicate - with pytest.raises(Exception): - test_fn = table.__getattribute__(fn_name) - test_fn(*fn_args, retryable_errors=input_retryables) - predicate_builder_mock.assert_called_once_with( - *expected_retryables, *extra_retryables + def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): + """Test that the gapic call is called with given rules""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row("key", call_rules) + assert mock_gapic.call_count == 1 + found_kwargs = mock_gapic.call_args_list[0][1] + assert found_kwargs["rules"] == expected_rules + assert found_kwargs["retry"] is None + + @pytest.mark.parametrize("rules", [[], None]) + def test_read_modify_write_no_rules(self, rules): + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.read_modify_write_row("key", rules=rules) + assert e.value.args[0] == "rules must contain at least one item" + + def test_read_modify_write_call_defaults(self): + instance = "instance1" + table_id = "table1" + project = "project1" + row_key = "row_key1" + with self._make_client(project=project) as client: + with client.get_table(instance, table_id) as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert ( + kwargs["table_name"] + == f"projects/{project}/instances/{instance}/tables/{table_id}" ) - retry_call_args = retry_fn_mock.call_args_list[0].args - assert retry_call_args[1] is expected_predicate + assert 
kwargs["app_profile_id"] is None + assert kwargs["row_key"] == row_key.encode() + assert kwargs["timeout"] > 1 - @pytest.mark.parametrize( - "fn_name,fn_args,gapic_fn", - [ - ("read_rows_stream", (ReadRowsQuery(),), "read_rows"), - ("read_rows", (ReadRowsQuery(),), "read_rows"), - ("read_row", (b"row_key",), "read_rows"), - ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), - ("row_exists", (b"row_key",), "read_rows"), - ("sample_row_keys", (), "sample_row_keys"), - ("mutate_row", (b"row_key", [mock.Mock()]), "mutate_row"), - ( - "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), - "mutate_rows", - ), - ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), - ( - "read_modify_write_row", - (b"row_key", mock.Mock()), - "read_modify_write_row", - ), - ], - ) - @pytest.mark.parametrize("include_app_profile", [True, False]) - def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): - """check that all requests attach proper metadata headers""" - profile = "profile" if include_app_profile else None - with mock.patch.object(BigtableClient, gapic_fn, mock.Mock()) as gapic_mock: - gapic_mock.side_effect = RuntimeError("stop early") - with self._make_client() as client: - table = Table(client, "instance-id", "table-id", profile) - try: - test_fn = table.__getattribute__(fn_name) - maybe_stream = test_fn(*fn_args) - [i for i in maybe_stream] - except Exception: - pass - kwargs = gapic_mock.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata + def test_read_modify_write_call_overrides(self): + row_key = b"row_key1" + expected_timeout = 12345 + profile_id = "profile1" + with self._make_client() as client: + with client.get_table( + "instance", "table_id", app_profile_id=profile_id + ) as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row( + row_key, mock.Mock(), operation_timeout=expected_timeout + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["app_profile_id"] is profile_id + assert kwargs["row_key"] == row_key + assert kwargs["timeout"] == expected_timeout + + def test_read_modify_write_string_key(self): + row_key = "string_row_key1" + with self._make_client() as client: + with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["row_key"] == row_key.encode() + + def test_read_modify_write_row_building(self): + """results from gapic call should be used to construct row""" + from google.cloud.bigtable.data.row import Row + from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + from google.cloud.bigtable_v2.types import Row as RowPB + + mock_response = ReadModifyWriteRowResponse(row=RowPB()) + with self._make_client() as client: + with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) 
as mock_gapic: + with mock.patch.object(Row, "_from_pb") as constructor_mock: + mock_gapic.return_value = mock_response + table.read_modify_write_row("key", mock.Mock()) + assert constructor_mock.call_count == 1 + constructor_mock.assert_called_once_with(mock_response.row) diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py index 9affe0d7a..20e160565 100644 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -30,6 +30,244 @@ from mock import AsyncMock +class Test_FlowControl: + @staticmethod + def _target_class(): + from google.cloud.bigtable.data._async.mutations_batcher import ( + _FlowControlAsync, + ) + + return _FlowControlAsync + + def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): + return self._target_class()(max_mutation_count, max_mutation_bytes) + + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def test_ctor(self): + max_mutation_count = 9 + max_mutation_bytes = 19 + instance = self._make_one(max_mutation_count, max_mutation_bytes) + assert instance._max_mutation_count == max_mutation_count + assert instance._max_mutation_bytes == max_mutation_bytes + assert instance._in_flight_mutation_count == 0 + assert instance._in_flight_mutation_bytes == 0 + assert isinstance(instance._capacity_condition, asyncio.Condition) + + def test_ctor_invalid_values(self): + """Test that values are positive, and fit within expected limits""" + with pytest.raises(ValueError) as e: + self._make_one(0, 1) + assert "max_mutation_count must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._make_one(1, 0) + assert "max_mutation_bytes must be greater than 0" in str(e.value) + + @pytest.mark.parametrize( + "max_count,max_size,existing_count,existing_size,new_count,new_size,expected", + [ + (1, 1, 0, 0, 0, 0, True), + (1, 1, 1, 1, 1, 1, False), + (10, 10, 0, 0, 0, 0, True), + (10, 10, 0, 0, 9, 9, True), + (10, 10, 0, 0, 11, 9, True), + (10, 10, 0, 1, 11, 9, True), + (10, 10, 1, 0, 11, 9, False), + (10, 10, 0, 0, 9, 11, True), + (10, 10, 1, 0, 9, 11, True), + (10, 10, 0, 1, 9, 11, False), + (10, 1, 0, 0, 1, 0, True), + (1, 10, 0, 0, 0, 8, True), + (float("inf"), float("inf"), 0, 0, 10000000000.0, 10000000000.0, True), + (8, 8, 0, 0, 10000000000.0, 10000000000.0, True), + (12, 12, 6, 6, 5, 5, True), + (12, 12, 5, 5, 6, 6, True), + (12, 12, 6, 6, 6, 6, True), + (12, 12, 6, 6, 7, 7, False), + (12, 12, 0, 0, 13, 13, True), + (12, 12, 12, 0, 0, 13, True), + (12, 12, 0, 12, 13, 0, True), + (12, 12, 1, 1, 13, 13, False), + (12, 12, 1, 1, 0, 13, False), + (12, 12, 1, 1, 13, 0, False), + ], + ) + def test__has_capacity( + self, + max_count, + max_size, + existing_count, + existing_size, + new_count, + new_size, + expected, + ): + """_has_capacity should return True if the new mutation will will not exceed the max count or size""" + instance = self._make_one(max_count, max_size) + instance._in_flight_mutation_count = existing_count + instance._in_flight_mutation_bytes = existing_size + assert instance._has_capacity(new_count, new_size) == expected + + @pytest.mark.parametrize( + "existing_count,existing_size,added_count,added_size,new_count,new_size", + [ + (0, 0, 0, 0, 0, 0), + (2, 2, 1, 1, 1, 1), + (2, 0, 1, 0, 1, 0), + (0, 2, 0, 1, 0, 1), + (10, 10, 0, 0, 10, 10), + (10, 10, 5, 5, 5, 5), + (0, 0, 1, 1, -1, -1), + ], + ) 
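For context, the cases above and below exercise _FlowControlAsync's in-flight bookkeeping. A minimal sketch of the capacity check, consistent with the parametrized expectations (not necessarily the shipped implementation):

    def _has_capacity(self, additional_count, additional_size):
        # Oversize requests are admitted once nothing else is in flight,
        # so cap the increments at the configured maximums first.
        acceptable_size = min(additional_size, self._max_mutation_bytes)
        acceptable_count = min(additional_count, self._max_mutation_count)
        return (
            self._in_flight_mutation_bytes + acceptable_size <= self._max_mutation_bytes
            and self._in_flight_mutation_count + acceptable_count <= self._max_mutation_count
        )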
+ def test_remove_from_flow_value_update( + self, + existing_count, + existing_size, + added_count, + added_size, + new_count, + new_size, + ): + """completed mutations should lower the inflight values""" + instance = self._make_one() + instance._in_flight_mutation_count = existing_count + instance._in_flight_mutation_bytes = existing_size + mutation = self._make_mutation(added_count, added_size) + instance.remove_from_flow(mutation) + assert instance._in_flight_mutation_count == new_count + assert instance._in_flight_mutation_bytes == new_size + + def test__remove_from_flow_unlock(self): + """capacity condition should notify after mutation is complete""" + import inspect + + instance = self._make_one(10, 10) + instance._in_flight_mutation_count = 10 + instance._in_flight_mutation_bytes = 10 + + def task_routine(): + with instance._capacity_condition: + instance._capacity_condition.wait_for( + lambda: instance._has_capacity(1, 1) + ) + + if inspect.iscoroutinefunction(task_routine): + task = asyncio.create_task(task_routine()) + task_alive = lambda: not task.done() + else: + import threading + + thread = threading.Thread(target=task_routine) + thread.start() + task_alive = thread.is_alive + asyncio.sleep(0.05) + assert task_alive() is True + mutation = self._make_mutation(count=0, size=5) + instance.remove_from_flow([mutation]) + asyncio.sleep(0.05) + assert instance._in_flight_mutation_count == 10 + assert instance._in_flight_mutation_bytes == 5 + assert task_alive() is True + instance._in_flight_mutation_bytes = 10 + mutation = self._make_mutation(count=5, size=0) + instance.remove_from_flow([mutation]) + asyncio.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 10 + assert task_alive() is True + instance._in_flight_mutation_count = 10 + mutation = self._make_mutation(count=5, size=5) + instance.remove_from_flow([mutation]) + asyncio.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 5 + assert task_alive() is False + + @pytest.mark.parametrize( + "mutations,count_cap,size_cap,expected_results", + [ + ([(5, 5), (1, 1), (1, 1)], 10, 10, [[(5, 5), (1, 1), (1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 1, 1, [[(1, 1)], [(1, 1)], [(1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 2, 10, [[(1, 1), (1, 1)], [(1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 10, 2, [[(1, 1), (1, 1)], [(1, 1)]]), + ( + [(1, 1), (5, 5), (4, 1), (1, 4), (1, 1)], + 5, + 5, + [[(1, 1)], [(5, 5)], [(4, 1), (1, 4)], [(1, 1)]], + ), + ], + ) + def test_add_to_flow(self, mutations, count_cap, size_cap, expected_results): + """Test batching with various flow control settings""" + mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] + instance = self._make_one(count_cap, size_cap) + i = 0 + for batch in instance.add_to_flow(mutation_objs): + expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + assert len(batch[j].mutations) == expected_batch[j][0] + assert batch[j].size() == expected_batch[j][1] + instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + @pytest.mark.parametrize( + "mutations,max_limit,expected_results", + [ + ([(1, 1)] * 11, 10, [[(1, 1)] * 10, [(1, 1)]]), + ([(1, 1)] * 10, 1, [[(1, 1)] for _ in range(10)]), + ([(1, 1)] * 10, 2, [[(1, 1), (1, 1)] for _ in range(5)]), + ], + ) + def test_add_to_flow_max_mutation_limits( + self, mutations, max_limit, expected_results + ): + """Test flow control running 
up against the max API limit + Should submit request early, even if the flow control has room for more""" + async_patch = mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + max_limit, + ) + sync_patch = mock.patch( + "google.cloud.bigtable.data.mutations._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + max_limit, + ) + with async_patch, sync_patch: + mutation_objs = [ + self._make_mutation(count=m[0], size=m[1]) for m in mutations + ] + instance = self._make_one(float("inf"), float("inf")) + i = 0 + for batch in instance.add_to_flow(mutation_objs): + expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + assert len(batch[j].mutations) == expected_batch[j][0] + assert batch[j].size() == expected_batch[j][1] + instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + def test_add_to_flow_oversize(self): + """mutations over the flow control limits should still be accepted""" + instance = self._make_one(2, 3) + large_size_mutation = self._make_mutation(count=1, size=10) + large_count_mutation = self._make_mutation(count=10, size=1) + results = [out for out in instance.add_to_flow([large_size_mutation])] + assert len(results) == 1 + instance.remove_from_flow(results[0]) + count_results = [out for out in instance.add_to_flow(large_count_mutation)] + assert len(count_results) == 1 + + class TestMutationsBatcher: def _get_target_class(self): from google.cloud.bigtable.data._async.mutations_batcher import ( @@ -867,241 +1105,3 @@ def test_customizable_retryable_errors(self, input_retryables, expected_retryabl ) retry_call_args = retry_fn_mock.call_args_list[0].args assert retry_call_args[1] is expected_predicate - - -class Test_FlowControl: - @staticmethod - def _target_class(): - from google.cloud.bigtable.data._async.mutations_batcher import ( - _FlowControlAsync, - ) - - return _FlowControlAsync - - def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): - return self._target_class()(max_mutation_count, max_mutation_bytes) - - @staticmethod - def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation - - def test_ctor(self): - max_mutation_count = 9 - max_mutation_bytes = 19 - instance = self._make_one(max_mutation_count, max_mutation_bytes) - assert instance._max_mutation_count == max_mutation_count - assert instance._max_mutation_bytes == max_mutation_bytes - assert instance._in_flight_mutation_count == 0 - assert instance._in_flight_mutation_bytes == 0 - assert isinstance(instance._capacity_condition, asyncio.Condition) - - def test_ctor_invalid_values(self): - """Test that values are positive, and fit within expected limits""" - with pytest.raises(ValueError) as e: - self._make_one(0, 1) - assert "max_mutation_count must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - self._make_one(1, 0) - assert "max_mutation_bytes must be greater than 0" in str(e.value) - - @pytest.mark.parametrize( - "max_count,max_size,existing_count,existing_size,new_count,new_size,expected", - [ - (1, 1, 0, 0, 0, 0, True), - (1, 1, 1, 1, 1, 1, False), - (10, 10, 0, 0, 0, 0, True), - (10, 10, 0, 0, 9, 9, True), - (10, 10, 0, 0, 11, 9, True), - (10, 10, 0, 1, 11, 9, True), - (10, 10, 1, 0, 11, 9, False), - (10, 10, 0, 0, 9, 11, True), - (10, 10, 1, 0, 9, 11, True), - (10, 10, 0, 1, 9, 11, False), - (10, 1, 0, 0, 1, 0, True), - (1, 10, 0, 0, 0, 8, 
True), - (float("inf"), float("inf"), 0, 0, 10000000000.0, 10000000000.0, True), - (8, 8, 0, 0, 10000000000.0, 10000000000.0, True), - (12, 12, 6, 6, 5, 5, True), - (12, 12, 5, 5, 6, 6, True), - (12, 12, 6, 6, 6, 6, True), - (12, 12, 6, 6, 7, 7, False), - (12, 12, 0, 0, 13, 13, True), - (12, 12, 12, 0, 0, 13, True), - (12, 12, 0, 12, 13, 0, True), - (12, 12, 1, 1, 13, 13, False), - (12, 12, 1, 1, 0, 13, False), - (12, 12, 1, 1, 13, 0, False), - ], - ) - def test__has_capacity( - self, - max_count, - max_size, - existing_count, - existing_size, - new_count, - new_size, - expected, - ): - """_has_capacity should return True if the new mutation will will not exceed the max count or size""" - instance = self._make_one(max_count, max_size) - instance._in_flight_mutation_count = existing_count - instance._in_flight_mutation_bytes = existing_size - assert instance._has_capacity(new_count, new_size) == expected - - @pytest.mark.parametrize( - "existing_count,existing_size,added_count,added_size,new_count,new_size", - [ - (0, 0, 0, 0, 0, 0), - (2, 2, 1, 1, 1, 1), - (2, 0, 1, 0, 1, 0), - (0, 2, 0, 1, 0, 1), - (10, 10, 0, 0, 10, 10), - (10, 10, 5, 5, 5, 5), - (0, 0, 1, 1, -1, -1), - ], - ) - def test_remove_from_flow_value_update( - self, - existing_count, - existing_size, - added_count, - added_size, - new_count, - new_size, - ): - """completed mutations should lower the inflight values""" - instance = self._make_one() - instance._in_flight_mutation_count = existing_count - instance._in_flight_mutation_bytes = existing_size - mutation = self._make_mutation(added_count, added_size) - instance.remove_from_flow(mutation) - assert instance._in_flight_mutation_count == new_count - assert instance._in_flight_mutation_bytes == new_size - - def test__remove_from_flow_unlock(self): - """capacity condition should notify after mutation is complete""" - import inspect - - instance = self._make_one(10, 10) - instance._in_flight_mutation_count = 10 - instance._in_flight_mutation_bytes = 10 - - def task_routine(): - with instance._capacity_condition: - instance._capacity_condition.wait_for( - lambda: instance._has_capacity(1, 1) - ) - - if inspect.iscoroutinefunction(task_routine): - task = asyncio.create_task(task_routine()) - task_alive = lambda: not task.done() - else: - import threading - - thread = threading.Thread(target=task_routine) - thread.start() - task_alive = thread.is_alive - asyncio.sleep(0.05) - assert task_alive() is True - mutation = self._make_mutation(count=0, size=5) - instance.remove_from_flow([mutation]) - asyncio.sleep(0.05) - assert instance._in_flight_mutation_count == 10 - assert instance._in_flight_mutation_bytes == 5 - assert task_alive() is True - instance._in_flight_mutation_bytes = 10 - mutation = self._make_mutation(count=5, size=0) - instance.remove_from_flow([mutation]) - asyncio.sleep(0.05) - assert instance._in_flight_mutation_count == 5 - assert instance._in_flight_mutation_bytes == 10 - assert task_alive() is True - instance._in_flight_mutation_count = 10 - mutation = self._make_mutation(count=5, size=5) - instance.remove_from_flow([mutation]) - asyncio.sleep(0.05) - assert instance._in_flight_mutation_count == 5 - assert instance._in_flight_mutation_bytes == 5 - assert task_alive() is False - - @pytest.mark.parametrize( - "mutations,count_cap,size_cap,expected_results", - [ - ([(5, 5), (1, 1), (1, 1)], 10, 10, [[(5, 5), (1, 1), (1, 1)]]), - ([(1, 1), (1, 1), (1, 1)], 1, 1, [[(1, 1)], [(1, 1)], [(1, 1)]]), - ([(1, 1), (1, 1), (1, 1)], 2, 10, [[(1, 1), (1, 1)], [(1, 1)]]), - 
([(1, 1), (1, 1), (1, 1)], 10, 2, [[(1, 1), (1, 1)], [(1, 1)]]), - ( - [(1, 1), (5, 5), (4, 1), (1, 4), (1, 1)], - 5, - 5, - [[(1, 1)], [(5, 5)], [(4, 1), (1, 4)], [(1, 1)]], - ), - ], - ) - def test_add_to_flow(self, mutations, count_cap, size_cap, expected_results): - """Test batching with various flow control settings""" - mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] - instance = self._make_one(count_cap, size_cap) - i = 0 - for batch in instance.add_to_flow(mutation_objs): - expected_batch = expected_results[i] - assert len(batch) == len(expected_batch) - for j in range(len(expected_batch)): - assert len(batch[j].mutations) == expected_batch[j][0] - assert batch[j].size() == expected_batch[j][1] - instance.remove_from_flow(batch) - i += 1 - assert i == len(expected_results) - - @pytest.mark.parametrize( - "mutations,max_limit,expected_results", - [ - ([(1, 1)] * 11, 10, [[(1, 1)] * 10, [(1, 1)]]), - ([(1, 1)] * 10, 1, [[(1, 1)] for _ in range(10)]), - ([(1, 1)] * 10, 2, [[(1, 1), (1, 1)] for _ in range(5)]), - ], - ) - def test_add_to_flow_max_mutation_limits( - self, mutations, max_limit, expected_results - ): - """Test flow control running up against the max API limit - Should submit request early, even if the flow control has room for more""" - async_patch = mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", - max_limit, - ) - sync_patch = mock.patch( - "google.cloud.bigtable.data.mutations._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", - max_limit, - ) - with async_patch, sync_patch: - mutation_objs = [ - self._make_mutation(count=m[0], size=m[1]) for m in mutations - ] - instance = self._make_one(float("inf"), float("inf")) - i = 0 - for batch in instance.add_to_flow(mutation_objs): - expected_batch = expected_results[i] - assert len(batch) == len(expected_batch) - for j in range(len(expected_batch)): - assert len(batch[j].mutations) == expected_batch[j][0] - assert batch[j].size() == expected_batch[j][1] - instance.remove_from_flow(batch) - i += 1 - assert i == len(expected_results) - - def test_add_to_flow_oversize(self): - """mutations over the flow control limits should still be accepted""" - instance = self._make_one(2, 3) - large_size_mutation = self._make_mutation(count=1, size=10) - large_count_mutation = self._make_mutation(count=10, size=1) - results = [out for out in instance.add_to_flow([large_size_mutation])] - assert len(results) == 1 - instance.remove_from_flow(results[0]) - count_results = [out for out in instance.add_to_flow(large_count_mutation)] - assert len(count_results) == 1 From e93f2acd2985b9861077a6961c9c0c5d3d89abfc Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 17:14:29 -0600 Subject: [PATCH 126/360] got test_client passing again --- tests/unit/data/_async/test_client.py | 8 +++++++- tests/unit/data/_async/test_read_rows_acceptance.py | 3 +++ tests/unit/data/_sync/__init__.py | 0 tests/unit/data/_sync/test_client.py | 8 +++++++- tests/unit/data/_sync/test_read_rows_acceptance.py | 5 +++++ 5 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 tests/unit/data/_sync/__init__.py diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 5f6af8be8..3f729d436 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1517,7 +1517,13 @@ def __init__(self, chunk_list, sleep_time): def __aiter__(self): return self + def __iter__(self): + return self + async def 
__anext__(self): + return self.__next__() + + def __next__(self): self.idx += 1 if len(self.chunk_list) > self.idx: if sleep_time: @@ -1527,7 +1533,7 @@ async def __anext__(self): raise chunk else: return ReadRowsResponse(chunks=[chunk]) - raise StopAsyncIteration + raise CrossSync.StopIteration def cancel(self): pass diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index c8b49bdab..29e2344c9 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -52,6 +52,9 @@ class TestFile(proto.Message): __test__ = False read_rows_tests = proto.RepeatedField(proto.MESSAGE, number=1, message=ReadRowsTest) +if not CrossSync.is_async: + from .._async.test_read_rows_acceptance import ReadRowsTest + from .._async.test_read_rows_acceptance import TestFile @CrossSync.sync_output( "tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", diff --git a/tests/unit/data/_sync/__init__.py b/tests/unit/data/_sync/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 6194cbc15..b1764db35 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -1264,7 +1264,13 @@ def __init__(self, chunk_list, sleep_time): def __aiter__(self): return self + def __iter__(self): + return self + def __anext__(self): + return self.__next__() + + def __next__(self): self.idx += 1 if len(self.chunk_list) > self.idx: if sleep_time: @@ -1274,7 +1280,7 @@ def __anext__(self): raise chunk else: return ReadRowsResponse(chunks=[chunk]) - raise StopIteration + raise CrossSync._Sync_Impl.StopIteration def cancel(self): pass diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync/test_read_rows_acceptance.py index 25e59f53f..aa7725d5e 100644 --- a/tests/unit/data/_sync/test_read_rows_acceptance.py +++ b/tests/unit/data/_sync/test_read_rows_acceptance.py @@ -22,6 +22,11 @@ from google.cloud.bigtable_v2 import ReadRowsResponse from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.row import Row +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +if not CrossSync._Sync_Impl.is_async: + from .._async.test_read_rows_acceptance import ReadRowsTest + from .._async.test_read_rows_acceptance import TestFile class TestReadRowsAcceptance: From 515f565cbce19bc50144a8599f670ec863d82612 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 17:25:15 -0600 Subject: [PATCH 127/360] added mock to crosssync --- .../cloud/bigtable/data/_sync/cross_sync.py | 17 ++++++ tests/unit/data/_async/test_client.py | 49 ++++++++-------- tests/unit/data/_sync/test_client.py | 58 ++++++++++++------- 3 files changed, 76 insertions(+), 48 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index b20214150..ea3c54902 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -63,6 +63,14 @@ def decorator(func): def drop_method(func): return func + @classmethod + def Mock(cls, *args, **kwargs): + try: + from unittest.mock import AsyncMock # type: ignore + except ImportError: # pragma: NO COVER + from mock import AsyncMock # type: ignore + return AsyncMock(*args, **kwargs) + @classmethod def sync_output( cls, @@ -209,6 +217,15 @@ class _Sync_Impl: 
generated_replacements: dict[type, str] = {} + @classmethod + def Mock(cls, *args, **kwargs): + # try/except added for compatibility with python < 3.8 + try: + from unittest.mock import Mock + except ImportError: # pragma: NO COVER + from mock import Mock # type: ignore + return Mock(*args, **kwargs) + @staticmethod def wait( futures: Sequence[CrossSync._Sync_Impl.Future[T]], diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 3f729d436..39d9e772a 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -37,10 +37,8 @@ # try/except added for compatibility with python < 3.8 try: from unittest import mock - from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER import mock # type: ignore - from mock import AsyncMock # type: ignore if CrossSync.is_async: from google.api_core import grpc_helpers_async @@ -77,7 +75,6 @@ "grpc_helpers_async": "grpc_helpers", "PooledChannelAsync": "PooledChannel", "BigtableAsyncClient": "BigtableClient", - "AsyncMock": "mock.Mock", } ) class TestBigtableDataClientAsync: @@ -224,7 +221,7 @@ async def test_veneer_grpc_headers(self): async def test_channel_pool_creation(self): pool_size = 14 with mock.patch.object( - grpc_helpers_async, "create_channel", AsyncMock() + grpc_helpers_async, "create_channel", CrossSync.Mock() ) as create_channel: client = self._make_client(project="project-id", pool_size=pool_size) assert create_channel.call_count == pool_size @@ -317,7 +314,7 @@ async def test__start_background_channel_refresh(self, pool_size): # should create background tasks for each channel with mock.patch.object( - self._get_target_class(), "_ping_and_warm_instances", AsyncMock() + self._get_target_class(), "_ping_and_warm_instances", CrossSync.Mock() ) as ping_and_warm: client = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False @@ -364,7 +361,7 @@ async def test__ping_and_warm_instances(self): client_mock, *args ) ) - with mock.patch.object(CrossSync, "gather_partials", AsyncMock()) as gather: + with mock.patch.object(CrossSync, "gather_partials", CrossSync.Mock()) as gather: # gather_partials is expected to call the function passed, and return the result gather.side_effect = lambda partials, **kwargs: [None for _ in partials] channel = mock.Mock() @@ -422,7 +419,7 @@ async def test__ping_and_warm_single_instance(self): client_mock, *args ) ) - with mock.patch.object(CrossSync, "gather_partials", AsyncMock()) as gather: + with mock.patch.object(CrossSync, "gather_partials", CrossSync.Mock()) as gather: gather.side_effect = lambda *args, **kwargs: [fn() for fn in args[0]] channel = mock.Mock() # test with large set of instances @@ -503,7 +500,7 @@ async def test__manage_channel_ping_and_warm(self): with mock.patch.object(*sleep_tuple): # stop process after replace_channel is called client_mock.transport.replace_channel.side_effect = asyncio.CancelledError - ping_and_warm = client_mock._ping_and_warm_instances = AsyncMock() + ping_and_warm = client_mock._ping_and_warm_instances = CrossSync.Mock() # should ping and warm old channel then new if sleep > 0 try: channel_idx = 1 @@ -681,7 +678,7 @@ async def test__register_instance(self): ) mock_channels = [mock.Mock() for i in range(5)] client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = AsyncMock() + client_mock._ping_and_warm_instances = CrossSync.Mock() table_mock = mock.Mock() await self._get_target_class()._register_instance( 
client_mock, "instance-1", table_mock @@ -767,7 +764,7 @@ async def test__register_instance_state( ) mock_channels = [mock.Mock() for i in range(5)] client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = AsyncMock() + client_mock._ping_and_warm_instances = CrossSync.Mock() table_mock = mock.Mock() # register instances for instance, table, profile in insert_instances: @@ -1063,7 +1060,7 @@ async def test_close(self): for task in client._channel_refresh_tasks: assert not task.done() with mock.patch.object( - PooledBigtableGrpcAsyncIOTransport, "close", AsyncMock() + PooledBigtableGrpcAsyncIOTransport, "close", CrossSync.Mock() ) as close_mock: await client.close() close_mock.assert_called_once() @@ -1079,7 +1076,7 @@ async def test_close_with_timeout(self): expected_timeout = 19 client = self._make_client(project="project-id", pool_size=pool_size) tasks = list(client._channel_refresh_tasks) - with mock.patch.object(CrossSync, "wait", AsyncMock()) as wait_for_mock: + with mock.patch.object(CrossSync, "wait", CrossSync.Mock()) as wait_for_mock: await client.close(timeout=expected_timeout) wait_for_mock.assert_called_once() if CrossSync.is_async: @@ -1091,7 +1088,7 @@ async def test_close_with_timeout(self): @CrossSync.pytest async def test_context_manager(self): # context manager should close the client cleanly - close_mock = AsyncMock() + close_mock = CrossSync.Mock() true_close = None async with self._make_client(project="project-id") as client: true_close = client.close() @@ -1124,7 +1121,7 @@ def test_client_ctor_sync(self): @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestTable", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "TableAsync": "Table", "BigtableAsyncClient": "BigtableClient", "AsyncMock": "mock.Mock"}, + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "TableAsync": "Table", "BigtableAsyncClient": "BigtableClient"}, ) class TestTableAsync: def _make_client(self, *args, **kwargs): @@ -1414,7 +1411,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ profile = "profile" if include_app_profile else None with mock.patch.object( - BigtableAsyncClient, gapic_fn, AsyncMock() + BigtableAsyncClient, gapic_fn, CrossSync.Mock() ) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") async with self._make_client() as client: @@ -1441,7 +1438,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestReadRows", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "__aiter__": "__iter__", "__anext__": "__next__", "StopAsyncIteration": "StopIteration", "_ReadRowsOperationAsync": "_ReadRowsOperation", "TestTableAsync": "TestTable"}, + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "_ReadRowsOperationAsync": "_ReadRowsOperation", "TestTableAsync": "TestTable"}, ) class TestReadRowsAsync: """ @@ -1521,9 +1518,6 @@ def __iter__(self): return self async def __anext__(self): - return self.__next__() - - def __next__(self): self.idx += 1 if len(self.chunk_list) > self.idx: if sleep_time: @@ -1535,6 +1529,9 @@ def __next__(self): return ReadRowsResponse(chunks=[chunk]) raise CrossSync.StopIteration + def __next__(self): + return self.__anext__() + def cancel(self): pass @@ -2177,7 +2174,7 @@ async def mock_call(*args, **kwargs): @CrossSync.sync_output( "tests.unit.data._sync.test_client.TestSampleRowKeys", - 
replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "AsyncMock": "mock.Mock"}, + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, ) class TestSampleRowKeysAsync: def _make_client(self, *args, **kwargs): @@ -2202,7 +2199,7 @@ async def test_sample_row_keys(self): async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream(samples) result = await table.sample_row_keys() @@ -2240,7 +2237,7 @@ async def test_sample_row_keys_default_timeout(self): default_attempt_timeout=expected_timeout, ) as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) result = await table.sample_row_keys() @@ -2263,7 +2260,7 @@ async def test_sample_row_keys_gapic_params(self): instance, table_id, app_profile_id=expected_profile ) as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) await table.sample_row_keys(attempt_timeout=expected_timeout) @@ -2294,7 +2291,7 @@ async def test_sample_row_keys_retryable_errors(self, retryable_exception): async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.side_effect = retryable_exception("mock") with pytest.raises(DeadlineExceeded) as e: @@ -2323,7 +2320,7 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.side_effect = non_retryable_exception("mock") with pytest.raises(non_retryable_exception): @@ -2483,7 +2480,7 @@ async def test_mutate_row_metadata(self, include_app_profile): async with self._make_client() as client: async with client.get_table("i", "t", app_profile_id=profile) as table: with mock.patch.object( - client._gapic_client, "mutate_row", AsyncMock() + client._gapic_client, "mutate_row", CrossSync.Mock() ) as read_rows: await table.mutate_row("rk", mock.Mock()) kwargs = read_rows.call_args_list[0].kwargs diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index b1764db35..c5fd7bc78 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -32,10 +32,8 @@ try: from unittest import mock - from unittest.mock import AsyncMock except ImportError: import mock - from mock import AsyncMock if CrossSync._Sync_Impl.is_async: pass else: @@ -181,7 +179,7 @@ def test_veneer_grpc_headers(self): def test_channel_pool_creation(self): pool_size = 14 with mock.patch.object( - grpc_helpers, "create_channel", mock.Mock() + grpc_helpers, "create_channel", CrossSync._Sync_Impl.Mock() ) 
as create_channel: client = self._make_client(project="project-id", pool_size=pool_size) assert create_channel.call_count == pool_size @@ -257,7 +255,9 @@ def test__start_background_channel_refresh(self, pool_size): import concurrent.futures with mock.patch.object( - self._get_target_class(), "_ping_and_warm_instances", mock.Mock() + self._get_target_class(), + "_ping_and_warm_instances", + CrossSync._Sync_Impl.Mock(), ) as ping_and_warm: client = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False @@ -285,7 +285,7 @@ def test__ping_and_warm_instances(self): ) ) with mock.patch.object( - CrossSync._Sync_Impl, "gather_partials", mock.Mock() + CrossSync._Sync_Impl, "gather_partials", CrossSync._Sync_Impl.Mock() ) as gather: gather.side_effect = lambda partials, **kwargs: [None for _ in partials] channel = mock.Mock() @@ -337,7 +337,7 @@ def test__ping_and_warm_single_instance(self): ) ) with mock.patch.object( - CrossSync._Sync_Impl, "gather_partials", mock.Mock() + CrossSync._Sync_Impl, "gather_partials", CrossSync._Sync_Impl.Mock() ) as gather: gather.side_effect = lambda *args, **kwargs: [fn() for fn in args[0]] channel = mock.Mock() @@ -405,7 +405,9 @@ def test__manage_channel_ping_and_warm(self): ) with mock.patch.object(*sleep_tuple): client_mock.transport.replace_channel.side_effect = asyncio.CancelledError - ping_and_warm = client_mock._ping_and_warm_instances = mock.Mock() + ping_and_warm = ( + client_mock._ping_and_warm_instances + ) = CrossSync._Sync_Impl.Mock() try: channel_idx = 1 self._get_target_class()._manage_channel(client_mock, channel_idx, 10) @@ -570,7 +572,7 @@ def test__register_instance(self): ) mock_channels = [mock.Mock() for i in range(5)] client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = mock.Mock() + client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() table_mock = mock.Mock() self._get_target_class()._register_instance( client_mock, "instance-1", table_mock @@ -646,7 +648,7 @@ def test__register_instance_state( ) mock_channels = [mock.Mock() for i in range(5)] client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = mock.Mock() + client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() table_mock = mock.Mock() for instance, table, profile in insert_instances: table_mock.table_name = table @@ -908,7 +910,7 @@ def test_close(self): for task in client._channel_refresh_tasks: assert not task.done() with mock.patch.object( - PooledBigtableGrpcTransport, "close", mock.Mock() + PooledBigtableGrpcTransport, "close", CrossSync._Sync_Impl.Mock() ) as close_mock: client.close() close_mock.assert_called_once() @@ -924,7 +926,7 @@ def test_close_with_timeout(self): client = self._make_client(project="project-id", pool_size=pool_size) tasks = list(client._channel_refresh_tasks) with mock.patch.object( - CrossSync._Sync_Impl, "wait", mock.Mock() + CrossSync._Sync_Impl, "wait", CrossSync._Sync_Impl.Mock() ) as wait_for_mock: client.close(timeout=expected_timeout) wait_for_mock.assert_called_once() @@ -935,7 +937,7 @@ def test_close_with_timeout(self): client.close() def test_context_manager(self): - close_mock = mock.Mock() + close_mock = CrossSync._Sync_Impl.Mock() true_close = None with self._make_client(project="project-id") as client: true_close = client.close() @@ -1168,7 +1170,9 @@ def test_customizable_retryable_errors( def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): """check that all requests attach proper metadata 
headers""" profile = "profile" if include_app_profile else None - with mock.patch.object(BigtableClient, gapic_fn, mock.Mock()) as gapic_mock: + with mock.patch.object( + BigtableClient, gapic_fn, CrossSync._Sync_Impl.Mock() + ) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") with self._make_client() as client: table = Table(client, "instance-id", "table-id", profile) @@ -1268,9 +1272,6 @@ def __iter__(self): return self def __anext__(self): - return self.__next__() - - def __next__(self): self.idx += 1 if len(self.chunk_list) > self.idx: if sleep_time: @@ -1282,6 +1283,9 @@ def __next__(self): return ReadRowsResponse(chunks=[chunk]) raise CrossSync._Sync_Impl.StopIteration + def __next__(self): + return self.__anext__() + def cancel(self): pass @@ -1861,7 +1865,9 @@ def test_sample_row_keys(self): with self._make_client() as client: with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream(samples) result = table.sample_row_keys() @@ -1895,7 +1901,9 @@ def test_sample_row_keys_default_timeout(self): default_attempt_timeout=expected_timeout, ) as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) result = table.sample_row_keys() @@ -1915,7 +1923,9 @@ def test_sample_row_keys_gapic_params(self): instance, table_id, app_profile_id=expected_profile ) as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) table.sample_row_keys(attempt_timeout=expected_timeout) @@ -1940,7 +1950,9 @@ def test_sample_row_keys_retryable_errors(self, retryable_exception): with self._make_client() as client: with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), ) as sample_row_keys: sample_row_keys.side_effect = retryable_exception("mock") with pytest.raises(DeadlineExceeded) as e: @@ -1966,7 +1978,9 @@ def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): with self._make_client() as client: with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", mock.Mock() + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), ) as sample_row_keys: sample_row_keys.side_effect = non_retryable_exception("mock") with pytest.raises(non_retryable_exception): @@ -2100,7 +2114,7 @@ def test_mutate_row_metadata(self, include_app_profile): with self._make_client() as client: with client.get_table("i", "t", app_profile_id=profile) as table: with mock.patch.object( - client._gapic_client, "mutate_row", AsyncMock() + client._gapic_client, "mutate_row", CrossSync._Sync_Impl.Mock() ) as read_rows: table.mutate_row("rk", mock.Mock()) kwargs = read_rows.call_args_list[0].kwargs From 4c6bac24c43b6affaabfadcae9aee62ee1d81c20 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 17:42:20 -0600 Subject: 
[PATCH 128/360] more targeted replacements --- tests/unit/data/_async/test_client.py | 73 ++++++++++----------------- tests/unit/data/_sync/test_client.py | 12 +++-- 2 files changed, 34 insertions(+), 51 deletions(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 39d9e772a..433de9b72 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -69,8 +69,6 @@ "tests.unit.data._sync.test_client.TestBigtableDataClient", replace_symbols={ "TestTableAsync": "TestTable", - "BigtableDataClientAsync": "BigtableDataClient", - "TableAsync": "Table", "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", "grpc_helpers_async": "grpc_helpers", "PooledChannelAsync": "PooledChannel", @@ -79,6 +77,7 @@ ) class TestBigtableDataClientAsync: @staticmethod + @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}) def _get_target_class(): return BigtableDataClientAsync @@ -236,7 +235,6 @@ async def test_channel_pool_creation(self): @CrossSync.pytest async def test_channel_pool_rotation(self): pool_size = 7 - with mock.patch.object(PooledChannelAsync, "next_channel") as next_channel: client = self._make_client(project="project-id", pool_size=pool_size) assert len(client.transport._grpc_channel._pool) == pool_size @@ -1119,15 +1117,14 @@ def test_client_ctor_sync(self): assert client.project == "project-id" assert client._channel_refresh_tasks == [] -@CrossSync.sync_output( - "tests.unit.data._sync.test_client.TestTable", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "TableAsync": "Table", "BigtableAsyncClient": "BigtableClient"}, -) +@CrossSync.sync_output("tests.unit.data._sync.test_client.TestTable") class TestTableAsync: + @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @staticmethod + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def _get_target_class(): return TableAsync @@ -1199,14 +1196,12 @@ async def test_table_ctor_defaults(self): """ should provide default timeout values and app_profile_id """ - from google.cloud.bigtable.data._async.client import TableAsync - expected_table_id = "table-id" expected_instance_id = "instance-id" client = self._make_client() assert not client._active_instances - table = TableAsync( + table = self._get_target_class()( client, expected_instance_id, expected_table_id, @@ -1229,8 +1224,6 @@ async def test_table_ctor_invalid_timeout_values(self): """ bad timeout values should raise ValueError """ - from google.cloud.bigtable.data._async.client import TableAsync - client = self._make_client() timeout_pairs = [ @@ -1246,18 +1239,16 @@ async def test_table_ctor_invalid_timeout_values(self): ] for operation_timeout, attempt_timeout in timeout_pairs: with pytest.raises(ValueError) as e: - TableAsync(client, "", "", **{attempt_timeout: -1}) + self._get_target_class()(client, "", "", **{attempt_timeout: -1}) assert "attempt_timeout must be greater than 0" in str(e.value) with pytest.raises(ValueError) as e: - TableAsync(client, "", "", **{operation_timeout: -1}) + self._get_target_class()(client, "", "", **{operation_timeout: -1}) assert "operation_timeout must be greater than 0" in str(e.value) await client.close() @CrossSync.drop_method def test_table_ctor_sync(self): # initializing client in a sync context should raise RuntimeError - from 
google.cloud.bigtable.data._async.client import TableAsync - client = mock.Mock() with pytest.raises(RuntimeError) as e: TableAsync(client, "instance-id", "table-id") @@ -1405,17 +1396,16 @@ async def test_customizable_retryable_errors( ) @pytest.mark.parametrize("include_app_profile", [True, False]) @CrossSync.pytest + @CrossSync.convert(replace_symbols={"BigtableAsyncClient": "BigtableClient"}) async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): """check that all requests attach proper metadata headers""" - from google.cloud.bigtable.data import TableAsync - profile = "profile" if include_app_profile else None with mock.patch.object( BigtableAsyncClient, gapic_fn, CrossSync.Mock() ) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") async with self._make_client() as client: - table = TableAsync(client, "instance-id", "table-id", profile) + table = self._get_target_class()(client, "instance-id", "table-id", profile) try: test_fn = table.__getattribute__(fn_name) maybe_stream = await test_fn(*fn_args) @@ -1436,22 +1426,22 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ else: assert "app_profile_id=" not in goog_metadata -@CrossSync.sync_output( - "tests.unit.data._sync.test_client.TestReadRows", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "_ReadRowsOperationAsync": "_ReadRowsOperation", "TestTableAsync": "TestTable"}, -) +@CrossSync.sync_output("tests.unit.data._sync.test_client.TestReadRows") class TestReadRowsAsync: """ Tests for table.read_rows and related methods. """ @staticmethod + @CrossSync.convert(replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"}) def _get_operation_class(): return _ReadRowsOperationAsync + @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @CrossSync.convert(replace_symbols={"TestTableAsync": "TestTable"}) def _make_table(self, *args, **kwargs): client_mock = mock.Mock() client_mock._register_instance.side_effect = ( @@ -1948,11 +1938,9 @@ async def test_row_exists(self, return_value, expected_result): assert query.filter._to_dict() == expected_filter -@CrossSync.sync_output( - "tests.unit.data._sync.test_client.TestReadRowsSharded", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient", "TestReadRowsAsync": "TestReadRows"}, -) +@CrossSync.sync_output("tests.unit.data._sync.test_client.TestReadRowsSharded") class TestReadRowsShardedAsync: + @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -1965,6 +1953,7 @@ async def test_read_rows_sharded_empty_query(self): assert "empty sharded_query" in str(exc.value) @CrossSync.pytest + @CrossSync.convert(replace_symbols={"TestReadRowsAsync": "TestReadRows"}) async def test_read_rows_sharded_multiple_queries(self): """ Test with multiple queries. 
Should return results from both @@ -2172,11 +2161,9 @@ async def mock_call(*args, **kwargs): ) -@CrossSync.sync_output( - "tests.unit.data._sync.test_client.TestSampleRowKeys", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, -) +@CrossSync.sync_output("tests.unit.data._sync.test_client.TestSampleRowKeys") class TestSampleRowKeysAsync: + @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2327,11 +2314,9 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio await table.sample_row_keys() -@CrossSync.sync_output( - "tests.unit.data._sync.test_client.TestMutateRow", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, -) +@CrossSync.sync_output("tests.unit.data._sync.test_client.TestMutateRow",) class TestMutateRowAsync: + @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2506,11 +2491,9 @@ async def test_mutate_row_no_mutations(self, mutations): assert e.value.args[0] == "No mutations provided" -@CrossSync.sync_output( - "tests.unit.data._sync.test_client.TestBulkMutateRows", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, -) +@CrossSync.sync_output("tests.unit.data._sync.test_client.TestBulkMutateRows",) class TestBulkMutateRowsAsync: + @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2889,11 +2872,9 @@ async def test_bulk_mutate_error_recovery(self): await table.bulk_mutate_rows(entries, operation_timeout=1000) -@CrossSync.sync_output( - "tests.unit.data._sync.test_client.TestCheckAndMutateRow", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, -) +@CrossSync.sync_output("tests.unit.data._sync.test_client.TestCheckAndMutateRow") class TestCheckAndMutateRowAsync: + @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -3044,11 +3025,9 @@ async def test_check_and_mutate_mutations_parsing(self): ) -@CrossSync.sync_output( - "tests.unit.data._sync.test_client.TestReadModifyWriteRow", - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}, -) +@CrossSync.sync_output("tests.unit.data._sync.test_client.TestReadModifyWriteRow") class TestReadModifyWriteRowAsync: + @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index c5fd7bc78..86f235b79 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -1027,7 +1027,9 @@ def test_table_ctor_defaults(self): expected_instance_id = "instance-id" client = self._make_client() assert not client._active_instances - table = Table(client, expected_instance_id, expected_table_id) + table = self._get_target_class()( + client, expected_instance_id, expected_table_id + ) CrossSync._Sync_Impl.yield_to_event_loop() assert table.table_id == 
expected_table_id assert table.instance_id == expected_instance_id @@ -1057,10 +1059,10 @@ def test_table_ctor_invalid_timeout_values(self): ] for operation_timeout, attempt_timeout in timeout_pairs: with pytest.raises(ValueError) as e: - Table(client, "", "", **{attempt_timeout: -1}) + self._get_target_class()(client, "", "", **{attempt_timeout: -1}) assert "attempt_timeout must be greater than 0" in str(e.value) with pytest.raises(ValueError) as e: - Table(client, "", "", **{operation_timeout: -1}) + self._get_target_class()(client, "", "", **{operation_timeout: -1}) assert "operation_timeout must be greater than 0" in str(e.value) client.close() @@ -1175,7 +1177,9 @@ def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): ) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") with self._make_client() as client: - table = Table(client, "instance-id", "table-id", profile) + table = self._get_target_class()( + client, "instance-id", "table-id", profile + ) try: test_fn = table.__getattribute__(fn_name) maybe_stream = test_fn(*fn_args) From 998829e67259efcaa7c7e723fa7b04ef64242aae Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 3 Jul 2024 23:46:07 -0600 Subject: [PATCH 129/360] ran blacken --- .../bigtable/data/_async/_mutate_rows.py | 12 +- .../cloud/bigtable/data/_async/_read_rows.py | 6 +- google/cloud/bigtable/data/_async/client.py | 27 +++-- .../bigtable/data/_async/mutations_batcher.py | 8 +- .../cloud/bigtable/data/_sync/cross_sync.py | 27 ++++- .../cloud/bigtable/data/_sync/transformers.py | 108 ++++++++++++------ tests/unit/data/_async/test__mutate_rows.py | 12 +- tests/unit/data/_async/test_client.py | 86 ++++++++++---- .../data/_async/test_mutations_batcher.py | 5 +- .../data/_async/test_read_rows_acceptance.py | 2 + 10 files changed, 211 insertions(+), 82 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 47715cb99..f63cc617a 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -48,6 +48,7 @@ BigtableClient, ) + @dataclass class _EntryWithProto: """ @@ -57,6 +58,7 @@ class _EntryWithProto: entry: RowMutationEntry proto: types_pb.MutateRowsRequest.Entry + if not CrossSync.is_async: from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto @@ -83,10 +85,12 @@ class _MutateRowsOperationAsync: If not specified, the request will run until operation_timeout is reached. 
""" - @CrossSync.convert(replace_symbols={ - "BigtableAsyncClient": "BigtableClient", - "TableAsync": "Table", - }) + @CrossSync.convert( + replace_symbols={ + "BigtableAsyncClient": "BigtableClient", + "TableAsync": "Table", + } + ) def __init__( self, gapic_client: "BigtableAsyncClient", diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index f786de98d..62440ad92 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -49,9 +49,11 @@ class _ResetRow(Exception): def __init__(self, chunk): self.chunk = chunk + if not CrossSync.is_async: from google.cloud.bigtable.data._async._read_rows import _ResetRow + @CrossSync.sync_output( "google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", ) @@ -219,7 +221,9 @@ async def chunk_stream( current_key = None @staticmethod - @CrossSync.convert(replace_symbols={"__aiter__": "__iter__", "__anext__": "__next__"}) + @CrossSync.convert( + replace_symbols={"__aiter__": "__iter__", "__anext__": "__next__"} + ) async def merge_rows( chunks: CrossSync.Iterable[ReadRowsResponsePB.CellChunk] | None, ) -> CrossSync.Iterable[Row]: diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 9dd001543..307b41775 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -114,11 +114,13 @@ "google.cloud.bigtable.data._sync.client.BigtableDataClient", ) class BigtableDataClientAsync(ClientWithProject): - @CrossSync.convert(replace_symbols={ - "BigtableAsyncClient": "BigtableClient", - "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", - "AsyncPooledChannel": "PooledChannel", - }) + @CrossSync.convert( + replace_symbols={ + "BigtableAsyncClient": "BigtableClient", + "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", + "AsyncPooledChannel": "PooledChannel", + } + ) def __init__( self, *, @@ -498,7 +500,9 @@ class TableAsync: each call """ - @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}) + @CrossSync.convert( + replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} + ) def __init__( self, client: BigtableDataClientAsync, @@ -619,7 +623,12 @@ def __init__( f"{self.__class__.__name__} must be created within an async event loop context." ) from e - @CrossSync.convert(replace_symbols={"AsyncIterable": "Iterable", "_ReadRowsOperationAsync": "_ReadRowsOperation"}) + @CrossSync.convert( + replace_symbols={ + "AsyncIterable": "Iterable", + "_ReadRowsOperationAsync": "_ReadRowsOperation", + } + ) async def read_rows_stream( self, query: ReadRowsQuery, @@ -1113,7 +1122,9 @@ async def mutate_row( exception_factory=_helpers._retry_exception_factory, ) - @CrossSync.convert(replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"}) + @CrossSync.convert( + replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"} + ) async def bulk_mutate_rows( self, mutation_entries: list[RowMutationEntry], diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index d723c4118..39afabdde 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -213,7 +213,9 @@ class MutationsBatcherAsync: Defaults to the Table's default_mutate_rows_retryable_errors. 
""" - @CrossSync.convert(replace_symbols={"TableAsync": "Table", "_FlowControlAsync": "_FlowControl"}) + @CrossSync.convert( + replace_symbols={"TableAsync": "Table", "_FlowControlAsync": "_FlowControl"} + ) def __init__( self, table: TableAsync, @@ -358,7 +360,9 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): self._entries_processed_since_last_raise += len(new_entries) self._add_exceptions(found_exceptions) - @CrossSync.convert(replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"}) + @CrossSync.convert( + replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"} + ) async def _execute_mutate_rows( self, batch: list[RowMutationEntry] ) -> list[FailedMutationEntryError]: diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index ea3c54902..60644f7df 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -14,7 +14,22 @@ # from __future__ import annotations -from typing import TypeVar, Any, Awaitable, Callable, Coroutine, Sequence, Union, AsyncIterable, AsyncIterator, AsyncGenerator, Iterable, Iterator, Generator, TYPE_CHECKING +from typing import ( + TypeVar, + Any, + Awaitable, + Callable, + Coroutine, + Sequence, + Union, + AsyncIterable, + AsyncIterator, + AsyncGenerator, + Iterable, + Iterator, + Generator, + TYPE_CHECKING, +) import asyncio import sys @@ -53,7 +68,9 @@ class CrossSync: generated_replacements: dict[type, str] = {} @staticmethod - def convert(*, sync_name: str|None=None, replace_symbols: dict[str, str]|None=None): + def convert( + *, sync_name: str | None = None, replace_symbols: dict[str, str] | None = None + ): def decorator(func): return func @@ -76,18 +93,20 @@ def sync_output( cls, sync_path: str, *, - replace_symbols: dict["str", "str" | None ] | None = None, + replace_symbols: dict["str", "str" | None] | None = None, mypy_ignore: list[str] | None = None, include_file_imports: bool = False, ): # return the async class unchanged def decorator(async_cls): return async_cls + return decorator @staticmethod def pytest(func): import pytest + return pytest.mark.asyncio(func) @staticmethod @@ -301,6 +320,7 @@ def create_task( def yield_to_event_loop() -> None: pass + from google.cloud.bigtable.data._sync import transformers if __name__ == "__main__": @@ -311,6 +331,7 @@ def yield_to_event_loop() -> None: import itertools import black import autoflake + # find all cross_sync decorated classes search_root = sys.argv[1] # cross_sync_classes = load_classes_from_dir(search_root)\ diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index d3673f1b8..c1fd434bd 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -2,8 +2,8 @@ from dataclasses import dataclass, field -class SymbolReplacer(ast.NodeTransformer): +class SymbolReplacer(ast.NodeTransformer): def __init__(self, replacements): self.replacements = replacements @@ -48,7 +48,6 @@ def visit_Str(self, node): class AsyncToSync(ast.NodeTransformer): - def visit_Await(self, node): return self.visit(node.value) @@ -90,40 +89,55 @@ def visit_ListComp(self, node): generator.is_async = False return self.generic_visit(node) -class HandleCrossSyncDecorators(ast.NodeTransformer): +class HandleCrossSyncDecorators(ast.NodeTransformer): def visit_FunctionDef(self, node): if hasattr(node, "decorator_list"): found_list, node.decorator_list = 
node.decorator_list, [] for decorator in found_list: if "CrossSync" in ast.dump(decorator): - decorator_type = decorator.func.attr if hasattr(decorator, "func") else decorator.attr + decorator_type = ( + decorator.func.attr + if hasattr(decorator, "func") + else decorator.attr + ) if decorator_type == "convert": for subcommand in decorator.keywords: if subcommand.arg == "sync_name": node.name = subcommand.value.s if subcommand.arg == "replace_symbols": - replacements = {subcommand.value.keys[i].s: subcommand.value.values[i].s for i in range(len(subcommand.value.keys))} + replacements = { + subcommand.value.keys[i] + .s: subcommand.value.values[i] + .s + for i in range(len(subcommand.value.keys)) + } node = SymbolReplacer(replacements).visit(node) elif decorator_type == "pytest": pass elif decorator_type == "drop_method": return None else: - raise ValueError(f"Unsupported CrossSync decorator: {decorator_type}") + raise ValueError( + f"Unsupported CrossSync decorator: {decorator_type}" + ) else: # add non-crosssync decorators back node.decorator_list.append(decorator) return node + @dataclass class CrossSyncFileArtifact: """ Used to track an output file location. Collects a number of converted classes, and then writes them to disk """ + file_path: str - imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = field(default_factory=list) + imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = field( + default_factory=list + ) converted_classes: list[ast.ClassDef] = field(default_factory=list) contained_classes: set[str] = field(default_factory=set) mypy_ignore: list[str] = field(default_factory=list) @@ -139,43 +153,52 @@ def render(self, with_black=True, save_to_disk=False) -> str: "# Copyright 2024 Google LLC\n" "#\n" '# Licensed under the Apache License, Version 2.0 (the "License");\n' - '# you may not use this file except in compliance with the License.\n' - '# You may obtain a copy of the License at\n' - '#\n' - '# http://www.apache.org/licenses/LICENSE-2.0\n' - '#\n' - '# Unless required by applicable law or agreed to in writing, software\n' + "# you may not use this file except in compliance with the License.\n" + "# You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing, software\n" '# distributed under the License is distributed on an "AS IS" BASIS,\n' - '# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n' - '# See the License for the specific language governing permissions and\n' - '# limitations under the License.\n' - '#\n' - '# This file is automatically generated by CrossSync. Do not edit manually.\n' + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + "# See the License for the specific language governing permissions and\n" + "# limitations under the License.\n" + "#\n" + "# This file is automatically generated by CrossSync. 
Do not edit manually.\n" ) if self.mypy_ignore: - full_str += f'\n# mypy: disable-error-code="{",".join(self.mypy_ignore)}"\n\n' + full_str += ( + f'\n# mypy: disable-error-code="{",".join(self.mypy_ignore)}"\n\n' + ) full_str += "\n".join([ast.unparse(node) for node in self.imports]) full_str += "\n\n" full_str += "\n".join([ast.unparse(node) for node in self.converted_classes]) if with_black: import black import autoflake - full_str = black.format_str(autoflake.fix_code(full_str, remove_all_unused_imports=True), mode=black.FileMode()) + + full_str = black.format_str( + autoflake.fix_code(full_str, remove_all_unused_imports=True), + mode=black.FileMode(), + ) if save_to_disk: with open(self.file_path, "w") as f: f.write(full_str) return full_str -class CrossSyncClassParser(ast.NodeTransformer): +class CrossSyncClassParser(ast.NodeTransformer): def __init__(self, file_path): self.in_path = file_path self._artifact_dict = {} self.imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = [] - self.cross_sync_converter = SymbolReplacer({"CrossSync": "CrossSync._Sync_Impl"}) - + self.cross_sync_converter = SymbolReplacer( + {"CrossSync": "CrossSync._Sync_Impl"} + ) - def convert_file(self, artifacts:set[CrossSyncFileArtifact]|None=None) -> set[CrossSyncFileArtifact]: + def convert_file( + self, artifacts: set[CrossSyncFileArtifact] | None = None + ) -> set[CrossSyncFileArtifact]: """ Called to run a file through the transformer. If any classes are marked with a CrossSync decorator, they will be transformed and added to an artifact for the output file @@ -196,7 +219,10 @@ def visit_ClassDef(self, node): """ for decorator in node.decorator_list: if "CrossSync" in ast.dump(decorator): - kwargs = {kw.arg: self._convert_ast_to_py(kw.value) for kw in decorator.keywords} + kwargs = { + kw.arg: self._convert_ast_to_py(kw.value) + for kw in decorator.keywords + } # find the path to write the sync class to sync_path = kwargs.pop("sync_path", None) if not sync_path: @@ -204,21 +230,31 @@ def visit_ClassDef(self, node): out_file = "/".join(sync_path.rsplit(".")[:-1]) + ".py" sync_cls_name = sync_path.rsplit(".", 1)[-1] # find the artifact file for the save location - output_artifact = self._artifact_dict.get(out_file, CrossSyncFileArtifact(out_file)) + output_artifact = self._artifact_dict.get( + out_file, CrossSyncFileArtifact(out_file) + ) # write converted class details if not already present if sync_cls_name not in output_artifact.contained_classes: converted = self._transform_class(node, sync_cls_name, **kwargs) output_artifact.converted_classes.append(converted) # handle file-level mypy ignores - mypy_ignores = [s for s in kwargs.get("mypy_ignore", []) if s not in output_artifact.mypy_ignore] + mypy_ignores = [ + s + for s in kwargs.get("mypy_ignore", []) + if s not in output_artifact.mypy_ignore + ] output_artifact.mypy_ignore.extend(mypy_ignores) # handle file-level imports - if not output_artifact.imports and kwargs.get("include_file_imports", True): + if not output_artifact.imports and kwargs.get( + "include_file_imports", True + ): output_artifact.imports = self.imports self._artifact_dict[out_file] = output_artifact return node - def _transform_class(self, cls_ast: ast.ClassDef, new_name:str, replace_symbols=None, **kwargs) -> ast.ClassDef: + def _transform_class( + self, cls_ast: ast.ClassDef, new_name: str, replace_symbols=None, **kwargs + ) -> ast.ClassDef: """ Transform async class into sync one, by running through a series of transformers """ @@ -226,7 +262,9 @@ def 
_transform_class(self, cls_ast: ast.ClassDef, new_name:str, replace_symbols= cls_ast.name = new_name # strip CrossSync decorators if hasattr(cls_ast, "decorator_list"): - cls_ast.decorator_list = [d for d in cls_ast.decorator_list if "CrossSync" not in ast.dump(d)] + cls_ast.decorator_list = [ + d for d in cls_ast.decorator_list if "CrossSync" not in ast.dump(d) + ] # convert class contents cls_ast = AsyncToSync().visit(cls_ast) cls_ast = self.cross_sync_converter.visit(cls_ast) @@ -235,7 +273,9 @@ def _transform_class(self, cls_ast: ast.ClassDef, new_name:str, replace_symbols= cls_ast = HandleCrossSyncDecorators().visit(cls_ast) return cls_ast - def _get_imports(self, tree:ast.Module) -> list[ast.Import | ast.ImportFrom | ast.Try | ast.If]: + def _get_imports( + self, tree: ast.Module + ) -> list[ast.Import | ast.ImportFrom | ast.Try | ast.If]: """ Grab the imports from the top of the file """ @@ -254,6 +294,8 @@ def _convert_ast_to_py(self, ast_node): if isinstance(ast_node, ast.List): return [self._convert_ast_to_py(node) for node in ast_node.elts] if isinstance(ast_node, ast.Dict): - return {self._convert_ast_to_py(k): self._convert_ast_to_py(v) for k, v in zip(ast_node.keys, ast_node.values)} + return { + self._convert_ast_to_py(k): self._convert_ast_to_py(v) + for k, v in zip(ast_node.keys, ast_node.values) + } raise ValueError(f"Unsupported type {type(ast_node)}") - diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 55a6fdd40..20a7d7e47 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -42,7 +42,9 @@ def _target_class(self): return _MutateRowsOperationAsync else: - from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation + from google.cloud.bigtable.data._sync._mutate_rows import ( + _MutateRowsOperation, + ) return _MutateRowsOperation @@ -181,9 +183,7 @@ async def test_mutate_rows_operation(self): await instance.start() assert attempt_mock.call_count == 1 - @pytest.mark.parametrize( - "exc_type", [RuntimeError, ZeroDivisionError, Forbidden] - ) + @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) @CrossSync.pytest async def test_mutate_rows_attempt_exception(self, exc_type): """ @@ -209,9 +209,7 @@ async def test_mutate_rows_attempt_exception(self, exc_type): assert len(instance.errors) == 2 assert len(instance.remaining_indices) == 0 - @pytest.mark.parametrize( - "exc_type", [RuntimeError, ZeroDivisionError, Forbidden] - ) + @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) @CrossSync.pytest async def test_mutate_rows_exception(self, exc_type): """ diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 433de9b72..a197508c8 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -52,7 +52,10 @@ PooledChannel as PooledChannelAsync, ) from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - from google.cloud.bigtable.data._async.client import TableAsync, BigtableDataClientAsync + from google.cloud.bigtable.data._async.client import ( + TableAsync, + BigtableDataClientAsync, + ) else: from google.api_core import grpc_helpers from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient @@ -65,6 +68,7 @@ from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation from google.cloud.bigtable.data._sync.client import Table, BigtableDataClient + 
@CrossSync.sync_output( "tests.unit.data._sync.test_client.TestBigtableDataClient", replace_symbols={ @@ -73,11 +77,13 @@ "grpc_helpers_async": "grpc_helpers", "PooledChannelAsync": "PooledChannel", "BigtableAsyncClient": "BigtableClient", - } + }, ) class TestBigtableDataClientAsync: @staticmethod - @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}) + @CrossSync.convert( + replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} + ) def _get_target_class(): return BigtableDataClientAsync @@ -359,7 +365,9 @@ async def test__ping_and_warm_instances(self): client_mock, *args ) ) - with mock.patch.object(CrossSync, "gather_partials", CrossSync.Mock()) as gather: + with mock.patch.object( + CrossSync, "gather_partials", CrossSync.Mock() + ) as gather: # gather_partials is expected to call the function passed, and return the result gather.side_effect = lambda partials, **kwargs: [None for _ in partials] channel = mock.Mock() @@ -417,7 +425,9 @@ async def test__ping_and_warm_single_instance(self): client_mock, *args ) ) - with mock.patch.object(CrossSync, "gather_partials", CrossSync.Mock()) as gather: + with mock.patch.object( + CrossSync, "gather_partials", CrossSync.Mock() + ) as gather: gather.side_effect = lambda *args, **kwargs: [fn() for fn in args[0]] channel = mock.Mock() # test with large set of instances @@ -494,7 +504,9 @@ async def test__manage_channel_ping_and_warm(self): new_channel = mock.Mock() client_mock.transport.grpc_channel._create_channel.return_value = new_channel # should ping an warm all new channels, and old channels if sleeping - sleep_tuple = (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") + sleep_tuple = ( + (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") + ) with mock.patch.object(*sleep_tuple): # stop process after replace_channel is called client_mock.transport.replace_channel.side_effect = asyncio.CancelledError @@ -548,7 +560,9 @@ async def test__manage_channel_sleeps( with mock.patch.object(time, "time") as time_mock: time_mock.return_value = 0 sleep_tuple = ( - (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") + (asyncio, "sleep") + if CrossSync.is_async + else (threading.Event, "wait") ) with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = [None for i in range(num_cycles - 1)] + [ @@ -569,7 +583,9 @@ async def test__manage_channel_sleeps( if CrossSync.is_async: total_sleep = sum([call[0][0] for call in sleep.call_args_list]) else: - total_sleep = sum([call[1]["timeout"] for call in sleep.call_args_list]) + total_sleep = sum( + [call[1]["timeout"] for call in sleep.call_args_list] + ) assert ( abs(total_sleep - expected_sleep) < 0.1 ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" @@ -580,7 +596,9 @@ async def test__manage_channel_random(self): import random import threading - sleep_tuple = (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") + sleep_tuple = ( + (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") + ) with mock.patch.object(*sleep_tuple) as sleep: with mock.patch.object(random, "uniform") as uniform: uniform.return_value = 0 @@ -1117,9 +1135,12 @@ def test_client_ctor_sync(self): assert client.project == "project-id" assert client._channel_refresh_tasks == [] + @CrossSync.sync_output("tests.unit.data._sync.test_client.TestTable") class TestTableAsync: - @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": 
"TestBigtableDataClient"}) + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -1405,7 +1426,9 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ ) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") async with self._make_client() as client: - table = self._get_target_class()(client, "instance-id", "table-id", profile) + table = self._get_target_class()( + client, "instance-id", "table-id", profile + ) try: test_fn = table.__getattribute__(fn_name) maybe_stream = await test_fn(*fn_args) @@ -1426,6 +1449,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ else: assert "app_profile_id=" not in goog_metadata + @CrossSync.sync_output("tests.unit.data._sync.test_client.TestReadRows") class TestReadRowsAsync: """ @@ -1433,11 +1457,15 @@ class TestReadRowsAsync: """ @staticmethod - @CrossSync.convert(replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"}) + @CrossSync.convert( + replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} + ) def _get_operation_class(): return _ReadRowsOperationAsync - @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -1940,7 +1968,9 @@ async def test_row_exists(self, return_value, expected_result): @CrossSync.sync_output("tests.unit.data._sync.test_client.TestReadRowsSharded") class TestReadRowsShardedAsync: - @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2163,7 +2193,9 @@ async def mock_call(*args, **kwargs): @CrossSync.sync_output("tests.unit.data._sync.test_client.TestSampleRowKeys") class TestSampleRowKeysAsync: - @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2314,9 +2346,13 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio await table.sample_row_keys() -@CrossSync.sync_output("tests.unit.data._sync.test_client.TestMutateRow",) +@CrossSync.sync_output( + "tests.unit.data._sync.test_client.TestMutateRow", +) class TestMutateRowAsync: - @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2491,9 +2527,13 @@ async def test_mutate_row_no_mutations(self, mutations): assert e.value.args[0] == "No mutations provided" -@CrossSync.sync_output("tests.unit.data._sync.test_client.TestBulkMutateRows",) +@CrossSync.sync_output( + "tests.unit.data._sync.test_client.TestBulkMutateRows", +) class TestBulkMutateRowsAsync: - 
@CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2874,7 +2914,9 @@ async def test_bulk_mutate_error_recovery(self): @CrossSync.sync_output("tests.unit.data._sync.test_client.TestCheckAndMutateRow") class TestCheckAndMutateRowAsync: - @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -3027,7 +3069,9 @@ async def test_check_and_mutate_mutations_parsing(self): @CrossSync.sync_output("tests.unit.data._sync.test_client.TestReadModifyWriteRow") class TestReadModifyWriteRowAsync: - @CrossSync.convert(replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"}) + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 8ef05326d..21ede35bf 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -32,9 +32,7 @@ from mock import AsyncMock # type: ignore -@CrossSync.sync_output( - "tests.unit.data._sync.test_mutations_batcher.Test_FlowControl" -) +@CrossSync.sync_output("tests.unit.data._sync.test_mutations_batcher.Test_FlowControl") class Test_FlowControl: @staticmethod def _target_class(): @@ -311,6 +309,7 @@ async def test_add_to_flow_oversize(self): ] assert len(count_results) == 1 + @CrossSync.sync_output( "tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" ) diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 29e2344c9..8ec1b67e9 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -52,10 +52,12 @@ class TestFile(proto.Message): __test__ = False read_rows_tests = proto.RepeatedField(proto.MESSAGE, number=1, message=ReadRowsTest) + if not CrossSync.is_async: from .._async.test_read_rows_acceptance import ReadRowsTest from .._async.test_read_rows_acceptance import TestFile + @CrossSync.sync_output( "tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", ) From 40e961e401de8ce45f72af72416678c72ab43536 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 4 Jul 2024 00:18:27 -0600 Subject: [PATCH 130/360] fixed some lint issues --- .../bigtable/data/_async/_mutate_rows.py | 9 +++---- .../cloud/bigtable/data/_async/_read_rows.py | 9 +++---- .../bigtable/data/_async/mutations_batcher.py | 4 ++- .../cloud/bigtable/data/_sync/cross_sync.py | 26 ++++++------------- .../cloud/bigtable/data/_sync/transformers.py | 23 +++++++++++++--- tests/unit/data/_async/test_client.py | 20 +++++++++----- .../data/_async/test_read_rows_acceptance.py | 13 +++++----- 7 files changed, 56 insertions(+), 48 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index f63cc617a..4e4e6b491 100644 --- 
a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -34,6 +34,9 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if not CrossSync.is_async: + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto + if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -50,7 +53,7 @@ @dataclass -class _EntryWithProto: +class _EntryWithProto: # noqa: F811 """ A dataclass to hold a RowMutationEntry and its corresponding proto representation. """ @@ -59,10 +62,6 @@ class _EntryWithProto: proto: types_pb.MutateRowsRequest.Entry -if not CrossSync.is_async: - from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto - - @CrossSync.sync_output( "google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", ) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 62440ad92..15dd2b7bf 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -17,7 +17,6 @@ from typing import ( TYPE_CHECKING, - Awaitable, Sequence, ) @@ -37,6 +36,8 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if not CrossSync.is_async: + from google.cloud.bigtable.data._async._read_rows import _ResetRow if TYPE_CHECKING: if CrossSync.is_async: @@ -45,15 +46,11 @@ from google.cloud.bigtable.data._sync.client import Table # noqa: F401 -class _ResetRow(Exception): +class _ResetRow(Exception): # noqa: F811 def __init__(self, chunk): self.chunk = chunk -if not CrossSync.is_async: - from google.cloud.bigtable.data._async._read_rows import _ResetRow - - @CrossSync.sync_output( "google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", ) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 39afabdde..b7e55e9e1 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -37,7 +37,9 @@ if CrossSync.is_async: from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync else: - from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation + from google.cloud.bigtable.data._sync._mutate_rows import ( # noqa: F401 + _MutateRowsOperation, + ) if TYPE_CHECKING: diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 60644f7df..e857543f0 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -17,7 +17,6 @@ from typing import ( TypeVar, Any, - Awaitable, Callable, Coroutine, Sequence, @@ -25,11 +24,9 @@ AsyncIterable, AsyncIterator, AsyncGenerator, - Iterable, - Iterator, - Generator, TYPE_CHECKING, ) +import typing import asyncio import sys @@ -60,7 +57,7 @@ class CrossSync: Semaphore: TypeAlias = asyncio.Semaphore StopIteration: TypeAlias = StopAsyncIteration # type annotations - Awaitable: TypeAlias = Awaitable + Awaitable: TypeAlias = typing.Awaitable Iterable: TypeAlias = AsyncIterable Iterator: TypeAlias = AsyncIterator Generator: TypeAlias = AsyncGenerator @@ -227,12 +224,12 @@ class _Sync_Impl: Task: TypeAlias = concurrent.futures.Future Event: TypeAlias = threading.Event Semaphore: TypeAlias = threading.Semaphore - StopIteration: TypeAlias = StopAsyncIteration + StopIteration: TypeAlias = StopIteration # type annotations Awaitable: TypeAlias = 
Union[T] - Iterable: TypeAlias = Iterable - Iterator: TypeAlias = Iterator - Generator: TypeAlias = Generator + Iterable: TypeAlias = typing.Iterable + Iterator: TypeAlias = typing.Iterator + Generator: TypeAlias = typing.Generator generated_replacements: dict[type, str] = {} @@ -321,22 +318,15 @@ def yield_to_event_loop() -> None: pass -from google.cloud.bigtable.data._sync import transformers - if __name__ == "__main__": - import os import glob - import importlib - import inspect - import itertools - import black - import autoflake + from google.cloud.bigtable.data._sync import transformers # find all cross_sync decorated classes search_root = sys.argv[1] # cross_sync_classes = load_classes_from_dir(search_root)\ files = glob.glob(search_root + "/**/*.py", recursive=True) - artifacts = set() + artifacts: set[transformers.CrossSyncFileArtifact] = set() for file in files: converter = transformers.CrossSyncClassParser(file) converter.convert_file(artifacts) diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index c1fd434bd..fcff6187c 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -1,10 +1,25 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + import ast from dataclasses import dataclass, field class SymbolReplacer(ast.NodeTransformer): - def __init__(self, replacements): + def __init__(self, replacements:dict[str, str]): self.replacements = replacements def visit_Name(self, node): @@ -190,7 +205,7 @@ def render(self, with_black=True, save_to_disk=False) -> str: class CrossSyncClassParser(ast.NodeTransformer): def __init__(self, file_path): self.in_path = file_path - self._artifact_dict = {} + self._artifact_dict: dict[str, CrossSyncFileArtifact] = {} self.imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = [] self.cross_sync_converter = SymbolReplacer( {"CrossSync": "CrossSync._Sync_Impl"} @@ -207,7 +222,7 @@ def convert_file( self._artifact_dict = {f.file_path: f for f in artifacts or []} self.imports = self._get_imports(tree) self.visit(tree) - found = self._artifact_dict.values() + found = set(self._artifact_dict.values()) if artifacts is not None: artifacts.update(found) return found @@ -253,7 +268,7 @@ def visit_ClassDef(self, node): return node def _transform_class( - self, cls_ast: ast.ClassDef, new_name: str, replace_symbols=None, **kwargs + self, cls_ast: ast.ClassDef, new_name: str, replace_symbols:dict[str, str]|None=None, **kwargs ) -> ast.ClassDef: """ Transform async class into sync one, by running through a series of transformers diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index a197508c8..5f68d571c 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -57,16 +57,23 @@ BigtableDataClientAsync, ) else: - from google.api_core import grpc_helpers - from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + from google.api_core import grpc_helpers # noqa: F401 + from google.cloud.bigtable_v2.services.bigtable.client import ( # noqa: F401 + BigtableClient, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( # noqa: F401 PooledBigtableGrpcTransport, ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( # noqa: F401 PooledChannel, ) - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - from google.cloud.bigtable.data._sync.client import Table, BigtableDataClient + from google.cloud.bigtable.data._sync._read_rows import ( # noqa: F401 + _ReadRowsOperation, + ) + from google.cloud.bigtable.data._sync.client import ( # noqa: F401 + Table, + BigtableDataClient, + ) @CrossSync.sync_output( @@ -468,7 +475,6 @@ async def test__manage_channel_first_sleep( self, refresh_interval, wait_time, expected_sleep ): # first sleep time should be `refresh_interval` seconds after client init - import threading import time with mock.patch.object(time, "monotonic") as monotonic: diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 8ec1b67e9..1848b4300 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -28,10 +28,14 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if not CrossSync.is_async: + from .._async.test_read_rows_acceptance import ReadRowsTest # noqa: F401 + from .._async.test_read_rows_acceptance import TestFile # noqa: F401 + # TODO: autogenerate protos from # 
https://github.com/googleapis/conformance-tests/blob/main/bigtable/v2/proto/google/cloud/conformance/bigtable/v2/tests.proto -class ReadRowsTest(proto.Message): +class ReadRowsTest(proto.Message): # noqa: F811 class Result(proto.Message): row_key = proto.Field(proto.STRING, number=1) family_name = proto.Field(proto.STRING, number=2) @@ -48,16 +52,11 @@ class Result(proto.Message): results = proto.RepeatedField(proto.MESSAGE, number=3, message=Result) -class TestFile(proto.Message): +class TestFile(proto.Message): # noqa: F811 __test__ = False read_rows_tests = proto.RepeatedField(proto.MESSAGE, number=1, message=ReadRowsTest) -if not CrossSync.is_async: - from .._async.test_read_rows_acceptance import ReadRowsTest - from .._async.test_read_rows_acceptance import TestFile - - @CrossSync.sync_output( "tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", ) From ec63aa76771c2681123bab4239a0487074cd446b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 4 Jul 2024 00:44:36 -0600 Subject: [PATCH 131/360] got tests passing --- tests/unit/data/_async/test__mutate_rows.py | 16 +- tests/unit/data/_async/test__read_rows.py | 18 ++- .../data/_async/test_mutations_batcher.py | 141 ++++++++-------- .../data/_async/test_read_rows_acceptance.py | 21 ++- tests/unit/data/_sync/test__mutate_rows.py | 16 +- tests/unit/data/_sync/test__read_rows.py | 29 ++-- .../unit/data/_sync/test_mutations_batcher.py | 150 +++++++++--------- .../data/_sync/test_read_rows_acceptance.py | 22 ++- 8 files changed, 209 insertions(+), 204 deletions(-) diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 20a7d7e47..8743182a2 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -24,10 +24,8 @@ # try/except added for compatibility with python < 3.8 try: from unittest import mock - from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER import mock # type: ignore - from mock import AsyncMock # type: ignore @CrossSync.sync_output( @@ -51,7 +49,7 @@ def _target_class(self): def _make_one(self, *args, **kwargs): if not args: kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) - kwargs["table"] = kwargs.pop("table", AsyncMock()) + kwargs["table"] = kwargs.pop("table", CrossSync.Mock()) kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) @@ -76,7 +74,7 @@ async def _mock_stream(self, mutation_list, error_dict): ) def _make_mock_gapic(self, mutation_list, error_dict=None): - mock_fn = AsyncMock() + mock_fn = CrossSync.Mock() if error_dict is None: error_dict = {} mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( @@ -175,7 +173,7 @@ async def test_mutate_rows_operation(self): operation_timeout = 0.05 cls = self._target_class() with mock.patch( - f"{cls.__module__}.{cls.__name__}._run_attempt", AsyncMock() + f"{cls.__module__}.{cls.__name__}._run_attempt", CrossSync.Mock() ) as attempt_mock: instance = self._make_one( client, table, entries, operation_timeout, operation_timeout @@ -189,7 +187,7 @@ async def test_mutate_rows_attempt_exception(self, exc_type): """ exceptions raised from attempt should be raised in MutationsExceptionGroup """ - client = AsyncMock() + client = CrossSync.Mock() table = mock.Mock() entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 @@ -226,7 
+224,7 @@ async def test_mutate_rows_exception(self, exc_type): with mock.patch.object( self._target_class(), "_run_attempt", - AsyncMock(), + CrossSync.Mock(), ) as attempt_mock: attempt_mock.side_effect = expected_cause found_exc = None @@ -263,7 +261,7 @@ async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): with mock.patch.object( self._target_class(), "_run_attempt", - AsyncMock(), + CrossSync.Mock(), ) as attempt_mock: attempt_mock.side_effect = [expected_cause] * num_retries + [None] instance = self._make_one( @@ -293,7 +291,7 @@ async def test_mutate_rows_incomplete_ignored(self): with mock.patch.object( self._target_class(), "_run_attempt", - AsyncMock(), + CrossSync.Mock(), ) as attempt_mock: attempt_mock.side_effect = _MutateRowsIncomplete("ignored") found_exc = None diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 130ecae7a..c3e201990 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -15,13 +15,16 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if CrossSync.is_async: + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync +else: + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + # try/except added for compatibility with python < 3.8 try: from unittest import mock - from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER import mock # type: ignore - from mock import AsyncMock # type: ignore # noqa F401 TEST_FAMILY = "family_name" TEST_QUALIFIER = b"qualifier" @@ -32,7 +35,7 @@ @CrossSync.sync_output( "tests.unit.data._sync.test__read_rows.TestReadRowsOperation", ) -class TestReadRowsOperation: +class TestReadRowsOperationAsync: """ Tests helper functions in the ReadRowsOperation class in-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt @@ -40,9 +43,8 @@ class TestReadRowsOperation: """ @staticmethod + @CrossSync.convert(replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"}) def _get_target_class(): - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - return _ReadRowsOperationAsync def _make_one(self, *args, **kwargs): @@ -326,6 +328,7 @@ async def mock_stream(): assert "emit count exceeds row limit" in str(e.value) @CrossSync.pytest + @CrossSync.convert(sync_name="test_close", replace_symbols={"aclose": "close", "__anext__": "__next__"}) async def test_aclose(self): """ should be able to close a stream safely with aclose. 
@@ -346,15 +349,16 @@ async def mock_stream(): # read one row await gen.__anext__() await gen.aclose() - with pytest.raises(StopAsyncIteration): + with pytest.raises(CrossSync.StopIteration): await gen.__anext__() # try calling a second time await gen.aclose() # ensure close was propagated to wrapped generator - with pytest.raises(StopAsyncIteration): + with pytest.raises(CrossSync.StopIteration): await wrapped_gen.__anext__() @CrossSync.pytest + @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) async def test_retryable_ignore_repeated_rows(self): """ Duplicate rows should cause an invalid chunk error diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 21ede35bf..0a0aa8359 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -19,27 +19,32 @@ import google.api_core.retry from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.cloud.bigtable.data import TABLE_DEFAULT -from google.cloud.bigtable.data import TableAsync from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if CrossSync.is_async: + from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._async.mutations_batcher import ( + _FlowControlAsync, + MutationsBatcherAsync, + ) +else: + from google.cloud.bigtable.data._sync.client import Table + from google.cloud.bigtable.data._sync.mutations_batcher import _FlowControl, MutationsBatcher + + # try/except added for compatibility with python < 3.8 try: from unittest import mock - from unittest.mock import AsyncMock except ImportError: # pragma: NO COVER import mock # type: ignore - from mock import AsyncMock # type: ignore @CrossSync.sync_output("tests.unit.data._sync.test_mutations_batcher.Test_FlowControl") class Test_FlowControl: @staticmethod + @CrossSync.convert(replace_symbols={"_FlowControlAsync": "_FlowControl"}) def _target_class(): - from google.cloud.bigtable.data._async.mutations_batcher import ( - _FlowControlAsync, - ) - return _FlowControlAsync def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): @@ -60,7 +65,7 @@ def test_ctor(self): assert instance._max_mutation_bytes == max_mutation_bytes assert instance._in_flight_mutation_count == 0 assert instance._in_flight_mutation_bytes == 0 - assert isinstance(instance._capacity_condition, asyncio.Condition) + assert isinstance(instance._capacity_condition, CrossSync.Condition) def test_ctor_invalid_values(self): """Test that values are positive, and fit within expected limits""" @@ -156,8 +161,6 @@ async def test_remove_from_flow_value_update( @CrossSync.pytest async def test__remove_from_flow_unlock(self): """capacity condition should notify after mutation is complete""" - import inspect - instance = self._make_one(10, 10) instance._in_flight_mutation_count = 10 instance._in_flight_mutation_bytes = 10 @@ -168,7 +171,7 @@ async def task_routine(): lambda: instance._has_capacity(1, 1) ) - if inspect.iscoroutinefunction(task_routine): + if CrossSync.is_async: # for async class, build task to test flow unlock task = asyncio.create_task(task_routine()) task_alive = lambda: not task.done() # noqa @@ -179,14 +182,14 @@ async def task_routine(): thread = threading.Thread(target=task_routine) thread.start() task_alive = thread.is_alive - await asyncio.sleep(0.05) + await CrossSync.sleep(0.05) # should be blocked due to capacity assert task_alive() is True # try changing size mutation = 
self._make_mutation(count=0, size=5) await instance.remove_from_flow([mutation]) - await asyncio.sleep(0.05) + await CrossSync.sleep(0.05) assert instance._in_flight_mutation_count == 10 assert instance._in_flight_mutation_bytes == 5 assert task_alive() is True @@ -194,7 +197,7 @@ async def task_routine(): instance._in_flight_mutation_bytes = 10 mutation = self._make_mutation(count=5, size=0) await instance.remove_from_flow([mutation]) - await asyncio.sleep(0.05) + await CrossSync.sleep(0.05) assert instance._in_flight_mutation_count == 5 assert instance._in_flight_mutation_bytes == 10 assert task_alive() is True @@ -202,7 +205,7 @@ async def task_routine(): instance._in_flight_mutation_count = 10 mutation = self._make_mutation(count=5, size=5) await instance.remove_from_flow([mutation]) - await asyncio.sleep(0.05) + await CrossSync.sleep(0.05) assert instance._in_flight_mutation_count == 5 assert instance._in_flight_mutation_bytes == 5 # task should be complete @@ -270,7 +273,7 @@ async def test_add_to_flow_max_mutation_limits( max_limit, ) sync_patch = mock.patch( - "google.cloud.bigtable.data.mutations._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + "google.cloud.bigtable.data._sync.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", max_limit, ) with async_patch, sync_patch: @@ -314,18 +317,11 @@ async def test_add_to_flow_oversize(self): "tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" ) class TestMutationsBatcherAsync: - def _get_target_class(self): - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) + @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) + def _get_target_class(self): return MutationsBatcherAsync - @staticmethod - def is_async(): - # helepr function for changing tests between sync and async versions - return True - def _make_one(self, table=None, **kwargs): from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import ServiceUnavailable @@ -351,7 +347,7 @@ def _make_mutation(count=1, size=1): @CrossSync.pytest async def test_ctor_defaults(self): with mock.patch.object( - self._get_target_class(), "_timer_routine", return_value=asyncio.Future() + self._get_target_class(), "_timer_routine", return_value=CrossSync.Future() ) as flush_timer_mock: table = mock.Mock() table.default_mutate_rows_operation_timeout = 10 @@ -383,16 +379,16 @@ async def test_ctor_defaults(self): instance._retryable_errors == table.default_mutate_rows_retryable_errors ) - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert flush_timer_mock.call_count == 1 assert flush_timer_mock.call_args[0][0] == 5 - assert isinstance(instance._flush_timer, asyncio.Future) + assert isinstance(instance._flush_timer, CrossSync.Future) @CrossSync.pytest async def test_ctor_explicit(self): """Test with explicit parameters""" with mock.patch.object( - self._get_target_class(), "_timer_routine", return_value=asyncio.Future() + self._get_target_class(), "_timer_routine", return_value=CrossSync.Future() ) as flush_timer_mock: table = mock.Mock() flush_interval = 20 @@ -435,16 +431,16 @@ async def test_ctor_explicit(self): assert instance._operation_timeout == operation_timeout assert instance._attempt_timeout == attempt_timeout assert instance._retryable_errors == retryable_errors - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert flush_timer_mock.call_count == 1 assert flush_timer_mock.call_args[0][0] == flush_interval - assert isinstance(instance._flush_timer, 
asyncio.Future) + assert isinstance(instance._flush_timer, CrossSync.Future) @CrossSync.pytest async def test_ctor_no_flush_limits(self): """Test with None for flush limits""" with mock.patch.object( - self._get_target_class(), "_timer_routine", return_value=asyncio.Future() + self._get_target_class(), "_timer_routine", return_value=CrossSync.Future() ) as flush_timer_mock: table = mock.Mock() table.default_mutate_rows_operation_timeout = 10 @@ -469,10 +465,10 @@ async def test_ctor_no_flush_limits(self): assert instance._flow_control._in_flight_mutation_count == 0 assert instance._flow_control._in_flight_mutation_bytes == 0 assert instance._entries_processed_since_last_raise == 0 - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert flush_timer_mock.call_count == 1 assert flush_timer_mock.call_args[0][0] is None - assert isinstance(instance._flush_timer, asyncio.Future) + assert isinstance(instance._flush_timer, CrossSync.Future) @CrossSync.pytest async def test_ctor_invalid_values(self): @@ -484,6 +480,7 @@ async def test_ctor_invalid_values(self): self._make_one(batch_attempt_timeout=-1) assert "attempt_timeout must be greater than 0" in str(e.value) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def test_default_argument_consistency(self): """ We supply default arguments in MutationsBatcherAsync.__init__, and in @@ -521,7 +518,7 @@ async def test__start_flush_timer_w_empty_input(self, input_val): ) as flush_mock: # mock different method depending on sync vs async async with self._make_one() as instance: - if self.is_async(): + if CrossSync.is_async: sleep_obj, sleep_method = asyncio, "wait_for" else: sleep_obj, sleep_method = instance._closed, "wait" @@ -544,7 +541,7 @@ async def test__start_flush_timer_call_when_closed( await instance.close() flush_mock.reset_mock() # mock different method depending on sync vs async - if self.is_async(): + if CrossSync.is_async: sleep_obj, sleep_method = asyncio, "wait_for" else: sleep_obj, sleep_method = instance._closed, "wait" @@ -573,8 +570,9 @@ async def test__flush_timer(self, num_staged): await self._get_target_class()._timer_routine( instance, expected_sleep ) - # replace with np-op so there are no issues on close - instance._flush_timer = asyncio.Future() + if CrossSync.is_async: + # replace with np-op so there are no issues on close + instance._flush_timer = CrossSync.Future() assert sleep_mock.call_count == loop_num + 1 sleep_kwargs = sleep_mock.call_args[1] assert sleep_kwargs["timeout"] == expected_sleep @@ -585,15 +583,12 @@ async def test__flush_timer_close(self): """Timer should continue terminate after close""" with mock.patch.object(self._get_target_class(), "_schedule_flush"): async with self._make_one() as instance: - with mock.patch("asyncio.sleep"): - # let task run in background - await asyncio.sleep(0.5) - assert instance._flush_timer.done() is False - # close the batcher - await instance.close() - await asyncio.sleep(0.1) - # task should be complete - assert instance._flush_timer.done() is True + # let task run in background + assert instance._flush_timer.done() is False + # close the batcher + await instance.close() + # task should be complete + assert instance._flush_timer.done() is True @CrossSync.pytest async def test_append_closed(self): @@ -741,11 +736,11 @@ async def test_flush_flow_control_concurrent_requests(self): fake_mutations = [self._make_mutation(count=1) for _ in range(num_calls)] async with self._make_one(flow_control_max_mutation_count=1) as instance: with mock.patch.object( - 
instance, "_execute_mutate_rows", AsyncMock() + instance, "_execute_mutate_rows", CrossSync.Mock() ) as op_mock: # mock network calls async def mock_call(*args, **kwargs): - await asyncio.sleep(0.1) + await CrossSync.sleep(0.1) return [] op_mock.side_effect = mock_call @@ -753,13 +748,13 @@ async def mock_call(*args, **kwargs): # flush one large batch, that will be broken up into smaller batches instance._staged_entries = fake_mutations instance._schedule_flush() - await asyncio.sleep(0.01) + await CrossSync.sleep(0.01) # make room for new mutations for i in range(num_calls): await instance._flow_control.remove_from_flow( [self._make_mutation(count=1)] ) - await asyncio.sleep(0.01) + await CrossSync.sleep(0.01) # allow flushes to complete await instance._wait_for_batch_results(*instance._flush_jobs) duration = time.monotonic() - start_time @@ -784,7 +779,7 @@ async def test_schedule_flush_with_mutations(self): """if new mutations exist, should add a new flush task to _flush_jobs""" async with self._make_one() as instance: with mock.patch.object(instance, "_flush_internal") as flush_mock: - if not self.is_async(): + if not CrossSync.is_async: # simulate operation flush_mock.side_effect = lambda x: time.sleep(0.1) for i in range(1, 4): @@ -836,9 +831,9 @@ async def test_flush_clears_job_list(self): """ async with self._make_one() as instance: with mock.patch.object( - instance, "_flush_internal", AsyncMock() + instance, "_flush_internal", CrossSync.Mock() ) as flush_mock: - if not self.is_async(): + if not CrossSync.is_async: # simulate operation flush_mock.side_effect = lambda x: time.sleep(0.1) mutations = [self._make_mutation(count=1, size=1)] @@ -846,7 +841,7 @@ async def test_flush_clears_job_list(self): assert instance._flush_jobs == set() new_job = instance._schedule_flush() assert instance._flush_jobs == {new_job} - if self.is_async(): + if CrossSync.is_async: await new_job else: new_job.result() @@ -922,8 +917,8 @@ async def gen(num): @CrossSync.pytest async def test_timer_flush_end_to_end(self): """Flush should automatically trigger after flush_interval""" - num_nutations = 10 - mutations = [self._make_mutation(count=2, size=2)] * num_nutations + num_mutations = 10 + mutations = [self._make_mutation(count=2, size=2)] * num_mutations async with self._make_one(flush_interval=0.05) as instance: instance._table.default_operation_timeout = 10 @@ -932,23 +927,23 @@ async def test_timer_flush_end_to_end(self): instance._table.client._gapic_client, "mutate_rows" ) as gapic_mock: gapic_mock.side_effect = ( - lambda *args, **kwargs: self._mock_gapic_return(num_nutations) + lambda *args, **kwargs: self._mock_gapic_return(num_mutations) ) for m in mutations: await instance.append(m) assert instance._entries_processed_since_last_raise == 0 # let flush trigger due to timer - await asyncio.sleep(0.1) - assert instance._entries_processed_since_last_raise == num_nutations + await CrossSync.sleep(0.1) + assert instance._entries_processed_since_last_raise == num_mutations @CrossSync.pytest async def test__execute_mutate_rows(self): - if self.is_async(): + if CrossSync.is_async: mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" else: - mutate_path = "_sync._mutate_rows._MutateRowsOperation" + mutate_path = "_sync.mutations_batcher._MutateRowsOperation" with mock.patch(f"google.cloud.bigtable.data.{mutate_path}") as mutate_rows: - mutate_rows.return_value = AsyncMock() + mutate_rows.return_value = CrossSync.Mock() start_operation = mutate_rows().start table = mock.Mock() 
table.table_name = "test-table" @@ -976,10 +971,10 @@ async def test__execute_mutate_rows_returns_errors(self): FailedMutationEntryError, ) - if self.is_async(): + if CrossSync.is_async: mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" else: - mutate_path = "_sync._mutate_rows._MutateRowsOperation" + mutate_path = "_sync.mutations_batcher._MutateRowsOperation" with mock.patch( f"google.cloud.bigtable.data.{mutate_path}.start" ) as mutate_rows: @@ -1021,12 +1016,14 @@ async def test__raise_exceptions(self): instance._raise_exceptions() @CrossSync.pytest + @CrossSync.convert(sync_name="test___enter__", replace_symbols={"__aenter__": "__enter__"}) async def test___aenter__(self): """Should return self""" async with self._make_one() as instance: assert await instance.__aenter__() == instance @CrossSync.pytest + @CrossSync.convert(sync_name="test___exit__", replace_symbols={"__aexit__": "__exit__"}) async def test___aexit__(self): """aexit should call close""" async with self._make_one() as instance: @@ -1104,12 +1101,12 @@ async def test_timeout_args_passed(self): batch_operation_timeout and batch_attempt_timeout should be used in api calls """ - if self.is_async(): + if CrossSync.is_async: mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" else: - mutate_path = "_sync._mutate_rows._MutateRowsOperation" + mutate_path = "_sync.mutations_batcher._MutateRowsOperation" with mock.patch( - f"google.cloud.bigtable.data.{mutate_path}", return_value=AsyncMock() + f"google.cloud.bigtable.data.{mutate_path}", return_value=CrossSync.Mock() ) as mutate_rows: expected_operation_timeout = 17 expected_attempt_timeout = 13 @@ -1210,6 +1207,7 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): ([4], [core_exceptions.DeadlineExceeded]), ], ) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def test_customizable_retryable_errors( self, input_retryables, expected_retryables ): @@ -1217,15 +1215,10 @@ async def test_customizable_retryable_errors( Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. 
""" - retryn_fn = ( - "google.cloud.bigtable.data._sync.cross_sync.CrossSync.retry_target" - if "Async" in self._get_target_class().__name__ - else "google.api_core.retry.retry_target" - ) with mock.patch.object( google.api_core.retry, "if_exception_type" ) as predicate_builder_mock: - with mock.patch(retryn_fn) as retry_fn_mock: + with mock.patch.object(CrossSync, "retry_target") as retry_fn_mock: table = None with mock.patch("asyncio.create_task"): table = TableAsync(mock.Mock(), "instance", "table") diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 1848b4300..9b00ad236 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -28,9 +28,14 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -if not CrossSync.is_async: +if CrossSync.is_async: + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + from google.cloud.bigtable.data._async.client import BigtableDataClientAsync +else: from .._async.test_read_rows_acceptance import ReadRowsTest # noqa: F401 from .._async.test_read_rows_acceptance import TestFile # noqa: F401 + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation # noqa: F401 + from google.cloud.bigtable.data._sync.client import BigtableDataClient # noqa: F401 # TODO: autogenerate protos from @@ -62,15 +67,13 @@ class TestFile(proto.Message): # noqa: F811 ) class TestReadRowsAcceptanceAsync: @staticmethod + @CrossSync.convert(replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"}) def _get_operation_class(): - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - return _ReadRowsOperationAsync @staticmethod + @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}) def _get_client_class(): - from google.cloud.bigtable.data._async.client import BigtableDataClientAsync - return BigtableDataClientAsync def parse_readrows_acceptance_tests(): @@ -168,12 +171,18 @@ def __init__(self, chunk_list): def __aiter__(self): return self + def __iter__(self): + return self + async def __anext__(self): self.idx += 1 if len(self.chunk_list) > self.idx: chunk = self.chunk_list[self.idx] return ReadRowsResponse(chunks=[chunk]) - raise StopAsyncIteration + raise CrossSync.StopIteration + + def __next__(self): + return self.__anext__() def cancel(self): pass diff --git a/tests/unit/data/_sync/test__mutate_rows.py b/tests/unit/data/_sync/test__mutate_rows.py index bcdd1103c..63d4009c6 100644 --- a/tests/unit/data/_sync/test__mutate_rows.py +++ b/tests/unit/data/_sync/test__mutate_rows.py @@ -22,10 +22,8 @@ try: from unittest import mock - from unittest.mock import AsyncMock except ImportError: import mock - from mock import AsyncMock class TestMutateRowsOperation: @@ -46,7 +44,7 @@ def _target_class(self): def _make_one(self, *args, **kwargs): if not args: kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) - kwargs["table"] = kwargs.pop("table", AsyncMock()) + kwargs["table"] = kwargs.pop("table", CrossSync._Sync_Impl.Mock()) kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) @@ -71,7 +69,7 @@ def _mock_stream(self, mutation_list, error_dict): ) def _make_mock_gapic(self, mutation_list, error_dict=None): - mock_fn = AsyncMock() + mock_fn = 
CrossSync._Sync_Impl.Mock() if error_dict is None: error_dict = {} mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( @@ -151,7 +149,7 @@ def test_mutate_rows_operation(self): operation_timeout = 0.05 cls = self._target_class() with mock.patch( - f"{cls.__module__}.{cls.__name__}._run_attempt", AsyncMock() + f"{cls.__module__}.{cls.__name__}._run_attempt", CrossSync._Sync_Impl.Mock() ) as attempt_mock: instance = self._make_one( client, table, entries, operation_timeout, operation_timeout @@ -162,7 +160,7 @@ def test_mutate_rows_operation(self): @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) def test_mutate_rows_attempt_exception(self, exc_type): """exceptions raised from attempt should be raised in MutationsExceptionGroup""" - client = AsyncMock() + client = CrossSync._Sync_Impl.Mock() table = mock.Mock() entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 @@ -194,7 +192,7 @@ def test_mutate_rows_exception(self, exc_type): operation_timeout = 0.05 expected_cause = exc_type("abort") with mock.patch.object( - self._target_class(), "_run_attempt", AsyncMock() + self._target_class(), "_run_attempt", CrossSync._Sync_Impl.Mock() ) as attempt_mock: attempt_mock.side_effect = expected_cause found_exc = None @@ -222,7 +220,7 @@ def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): expected_cause = exc_type("retry") num_retries = 2 with mock.patch.object( - self._target_class(), "_run_attempt", AsyncMock() + self._target_class(), "_run_attempt", CrossSync._Sync_Impl.Mock() ) as attempt_mock: attempt_mock.side_effect = [expected_cause] * num_retries + [None] instance = self._make_one( @@ -247,7 +245,7 @@ def test_mutate_rows_incomplete_ignored(self): entries = [self._make_mutation()] operation_timeout = 0.05 with mock.patch.object( - self._target_class(), "_run_attempt", AsyncMock() + self._target_class(), "_run_attempt", CrossSync._Sync_Impl.Mock() ) as attempt_mock: attempt_mock.side_effect = _MutateRowsIncomplete("ignored") found_exc = None diff --git a/tests/unit/data/_sync/test__read_rows.py b/tests/unit/data/_sync/test__read_rows.py index bdf0a42ce..296e8e7f9 100644 --- a/tests/unit/data/_sync/test__read_rows.py +++ b/tests/unit/data/_sync/test__read_rows.py @@ -14,7 +14,12 @@ # # This file is automatically generated by CrossSync. Do not edit manually. import pytest +from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if CrossSync._Sync_Impl.is_async: + pass +else: + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation try: from unittest import mock except ImportError: @@ -30,9 +35,7 @@ class TestReadRowsOperation: @staticmethod def _get_target_class(): - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - - return _ReadRowsOperationAsync + return _ReadRowsOperation def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) @@ -299,7 +302,7 @@ def mock_stream(): pass assert "emit count exceeds row limit" in str(e.value) - def test_aclose(self): + def test_close(self): """should be able to close a stream safely with aclose. 
Closed generators should raise StopAsyncIteration on next yield""" @@ -314,13 +317,13 @@ def mock_stream(): wrapped_gen = mock_stream() mock_attempt.return_value = wrapped_gen gen = instance.start_operation() - gen.__anext__() - gen.aclose() - with pytest.raises(StopAsyncIteration): - gen.__anext__() - gen.aclose() - with pytest.raises(StopAsyncIteration): - wrapped_gen.__anext__() + gen.__next__() + gen.close() + with pytest.raises(CrossSync._Sync_Impl.StopIteration): + gen.__next__() + gen.close() + with pytest.raises(CrossSync._Sync_Impl.StopIteration): + wrapped_gen.__next__() def test_retryable_ignore_repeated_rows(self): """Duplicate rows should cause an invalid chunk error""" @@ -351,7 +354,7 @@ def mock_stream(): stream = self._get_target_class().chunk_stream( instance, mock_awaitable_stream() ) - stream.__anext__() + stream.__next__() with pytest.raises(InvalidChunk) as exc: - stream.__anext__() + stream.__next__() assert "row keys should be strictly increasing" in str(exc.value) diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py index 20e160565..e176b54e9 100644 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -20,24 +20,26 @@ import google.api_core.retry from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.cloud.bigtable.data import TABLE_DEFAULT -from google.cloud.bigtable.data import TableAsync - +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +if CrossSync._Sync_Impl.is_async: + pass +else: + from google.cloud.bigtable.data._sync.client import Table + from google.cloud.bigtable.data._sync.mutations_batcher import ( + _FlowControl, + MutationsBatcher, + ) try: from unittest import mock - from unittest.mock import AsyncMock except ImportError: import mock - from mock import AsyncMock class Test_FlowControl: @staticmethod def _target_class(): - from google.cloud.bigtable.data._async.mutations_batcher import ( - _FlowControlAsync, - ) - - return _FlowControlAsync + return _FlowControl def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): return self._target_class()(max_mutation_count, max_mutation_bytes) @@ -57,7 +59,7 @@ def test_ctor(self): assert instance._max_mutation_bytes == max_mutation_bytes assert instance._in_flight_mutation_count == 0 assert instance._in_flight_mutation_bytes == 0 - assert isinstance(instance._capacity_condition, asyncio.Condition) + assert isinstance(instance._capacity_condition, CrossSync._Sync_Impl.Condition) def test_ctor_invalid_values(self): """Test that values are positive, and fit within expected limits""" @@ -145,8 +147,6 @@ def test_remove_from_flow_value_update( def test__remove_from_flow_unlock(self): """capacity condition should notify after mutation is complete""" - import inspect - instance = self._make_one(10, 10) instance._in_flight_mutation_count = 10 instance._in_flight_mutation_bytes = 10 @@ -157,7 +157,7 @@ def task_routine(): lambda: instance._has_capacity(1, 1) ) - if inspect.iscoroutinefunction(task_routine): + if CrossSync._Sync_Impl.is_async: task = asyncio.create_task(task_routine()) task_alive = lambda: not task.done() else: @@ -166,25 +166,25 @@ def task_routine(): thread = threading.Thread(target=task_routine) thread.start() task_alive = thread.is_alive - asyncio.sleep(0.05) + CrossSync._Sync_Impl.sleep(0.05) assert task_alive() is True mutation = self._make_mutation(count=0, size=5) instance.remove_from_flow([mutation]) - 
asyncio.sleep(0.05) + CrossSync._Sync_Impl.sleep(0.05) assert instance._in_flight_mutation_count == 10 assert instance._in_flight_mutation_bytes == 5 assert task_alive() is True instance._in_flight_mutation_bytes = 10 mutation = self._make_mutation(count=5, size=0) instance.remove_from_flow([mutation]) - asyncio.sleep(0.05) + CrossSync._Sync_Impl.sleep(0.05) assert instance._in_flight_mutation_count == 5 assert instance._in_flight_mutation_bytes == 10 assert task_alive() is True instance._in_flight_mutation_count = 10 mutation = self._make_mutation(count=5, size=5) instance.remove_from_flow([mutation]) - asyncio.sleep(0.05) + CrossSync._Sync_Impl.sleep(0.05) assert instance._in_flight_mutation_count == 5 assert instance._in_flight_mutation_bytes == 5 assert task_alive() is False @@ -237,7 +237,7 @@ def test_add_to_flow_max_mutation_limits( max_limit, ) sync_patch = mock.patch( - "google.cloud.bigtable.data.mutations._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + "google.cloud.bigtable.data._sync.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", max_limit, ) with async_patch, sync_patch: @@ -270,15 +270,7 @@ def test_add_to_flow_oversize(self): class TestMutationsBatcher: def _get_target_class(self): - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) - - return MutationsBatcherAsync - - @staticmethod - def is_async(): - return True + return MutationsBatcher def _make_one(self, table=None, **kwargs): from google.api_core.exceptions import DeadlineExceeded @@ -303,7 +295,9 @@ def _make_mutation(count=1, size=1): def test_ctor_defaults(self): with mock.patch.object( - self._get_target_class(), "_timer_routine", return_value=asyncio.Future() + self._get_target_class(), + "_timer_routine", + return_value=CrossSync._Sync_Impl.Future(), ) as flush_timer_mock: table = mock.Mock() table.default_mutate_rows_operation_timeout = 10 @@ -335,15 +329,17 @@ def test_ctor_defaults(self): instance._retryable_errors == table.default_mutate_rows_retryable_errors ) - asyncio.sleep(0) + CrossSync._Sync_Impl.yield_to_event_loop() assert flush_timer_mock.call_count == 1 assert flush_timer_mock.call_args[0][0] == 5 - assert isinstance(instance._flush_timer, asyncio.Future) + assert isinstance(instance._flush_timer, CrossSync._Sync_Impl.Future) def test_ctor_explicit(self): """Test with explicit parameters""" with mock.patch.object( - self._get_target_class(), "_timer_routine", return_value=asyncio.Future() + self._get_target_class(), + "_timer_routine", + return_value=CrossSync._Sync_Impl.Future(), ) as flush_timer_mock: table = mock.Mock() flush_interval = 20 @@ -386,15 +382,17 @@ def test_ctor_explicit(self): assert instance._operation_timeout == operation_timeout assert instance._attempt_timeout == attempt_timeout assert instance._retryable_errors == retryable_errors - asyncio.sleep(0) + CrossSync._Sync_Impl.yield_to_event_loop() assert flush_timer_mock.call_count == 1 assert flush_timer_mock.call_args[0][0] == flush_interval - assert isinstance(instance._flush_timer, asyncio.Future) + assert isinstance(instance._flush_timer, CrossSync._Sync_Impl.Future) def test_ctor_no_flush_limits(self): """Test with None for flush limits""" with mock.patch.object( - self._get_target_class(), "_timer_routine", return_value=asyncio.Future() + self._get_target_class(), + "_timer_routine", + return_value=CrossSync._Sync_Impl.Future(), ) as flush_timer_mock: table = mock.Mock() table.default_mutate_rows_operation_timeout = 10 @@ -419,10 +417,10 @@ def test_ctor_no_flush_limits(self): 
assert instance._flow_control._in_flight_mutation_count == 0 assert instance._flow_control._in_flight_mutation_bytes == 0 assert instance._entries_processed_since_last_raise == 0 - asyncio.sleep(0) + CrossSync._Sync_Impl.yield_to_event_loop() assert flush_timer_mock.call_count == 1 assert flush_timer_mock.call_args[0][0] is None - assert isinstance(instance._flush_timer, asyncio.Future) + assert isinstance(instance._flush_timer, CrossSync._Sync_Impl.Future) def test_ctor_invalid_values(self): """Test that timeout values are positive, and fit within expected limits""" @@ -440,7 +438,7 @@ def test_default_argument_consistency(self): import inspect get_batcher_signature = dict( - inspect.signature(TableAsync.mutations_batcher).parameters + inspect.signature(Table.mutations_batcher).parameters ) get_batcher_signature.pop("self") batcher_init_signature = dict( @@ -463,7 +461,7 @@ def test__start_flush_timer_w_empty_input(self, input_val): self._get_target_class(), "_schedule_flush" ) as flush_mock: with self._make_one() as instance: - if self.is_async(): + if CrossSync._Sync_Impl.is_async: sleep_obj, sleep_method = (asyncio, "wait_for") else: sleep_obj, sleep_method = (instance._closed, "wait") @@ -482,7 +480,7 @@ def test__start_flush_timer_call_when_closed(self): with self._make_one() as instance: instance.close() flush_mock.reset_mock() - if self.is_async(): + if CrossSync._Sync_Impl.is_async: sleep_obj, sleep_method = (asyncio, "wait_for") else: sleep_obj, sleep_method = (instance._closed, "wait") @@ -512,7 +510,8 @@ def test__flush_timer(self, num_staged): self._get_target_class()._timer_routine( instance, expected_sleep ) - instance._flush_timer = asyncio.Future() + if CrossSync._Sync_Impl.is_async: + instance._flush_timer = CrossSync._Sync_Impl.Future() assert sleep_mock.call_count == loop_num + 1 sleep_kwargs = sleep_mock.call_args[1] assert sleep_kwargs["timeout"] == expected_sleep @@ -522,12 +521,9 @@ def test__flush_timer_close(self): """Timer should continue terminate after close""" with mock.patch.object(self._get_target_class(), "_schedule_flush"): with self._make_one() as instance: - with mock.patch("asyncio.sleep"): - asyncio.sleep(0.5) - assert instance._flush_timer.done() is False - instance.close() - asyncio.sleep(0.1) - assert instance._flush_timer.done() is True + assert instance._flush_timer.done() is False + instance.close() + assert instance._flush_timer.done() is True def test_append_closed(self): """Should raise exception""" @@ -657,23 +653,23 @@ def test_flush_flow_control_concurrent_requests(self): fake_mutations = [self._make_mutation(count=1) for _ in range(num_calls)] with self._make_one(flow_control_max_mutation_count=1) as instance: with mock.patch.object( - instance, "_execute_mutate_rows", AsyncMock() + instance, "_execute_mutate_rows", CrossSync._Sync_Impl.Mock() ) as op_mock: def mock_call(*args, **kwargs): - asyncio.sleep(0.1) + CrossSync._Sync_Impl.sleep(0.1) return [] op_mock.side_effect = mock_call start_time = time.monotonic() instance._staged_entries = fake_mutations instance._schedule_flush() - asyncio.sleep(0.01) + CrossSync._Sync_Impl.sleep(0.01) for i in range(num_calls): instance._flow_control.remove_from_flow( [self._make_mutation(count=1)] ) - asyncio.sleep(0.01) + CrossSync._Sync_Impl.sleep(0.01) instance._wait_for_batch_results(*instance._flush_jobs) duration = time.monotonic() - start_time assert len(instance._oldest_exceptions) == 0 @@ -694,7 +690,7 @@ def test_schedule_flush_with_mutations(self): """if new mutations exist, should add a new 
flush task to _flush_jobs""" with self._make_one() as instance: with mock.patch.object(instance, "_flush_internal") as flush_mock: - if not self.is_async(): + if not CrossSync._Sync_Impl.is_async: flush_mock.side_effect = lambda x: time.sleep(0.1) for i in range(1, 4): mutation = mock.Mock() @@ -738,16 +734,16 @@ def test_flush_clears_job_list(self): and removed when it completes""" with self._make_one() as instance: with mock.patch.object( - instance, "_flush_internal", AsyncMock() + instance, "_flush_internal", CrossSync._Sync_Impl.Mock() ) as flush_mock: - if not self.is_async(): + if not CrossSync._Sync_Impl.is_async: flush_mock.side_effect = lambda x: time.sleep(0.1) mutations = [self._make_mutation(count=1, size=1)] instance._staged_entries = mutations assert instance._flush_jobs == set() new_job = instance._schedule_flush() assert instance._flush_jobs == {new_job} - if self.is_async(): + if CrossSync._Sync_Impl.is_async: new_job else: new_job.result() @@ -817,8 +813,8 @@ def gen(num): def test_timer_flush_end_to_end(self): """Flush should automatically trigger after flush_interval""" - num_nutations = 10 - mutations = [self._make_mutation(count=2, size=2)] * num_nutations + num_mutations = 10 + mutations = [self._make_mutation(count=2, size=2)] * num_mutations with self._make_one(flush_interval=0.05) as instance: instance._table.default_operation_timeout = 10 instance._table.default_attempt_timeout = 9 @@ -826,21 +822,21 @@ def test_timer_flush_end_to_end(self): instance._table.client._gapic_client, "mutate_rows" ) as gapic_mock: gapic_mock.side_effect = ( - lambda *args, **kwargs: self._mock_gapic_return(num_nutations) + lambda *args, **kwargs: self._mock_gapic_return(num_mutations) ) for m in mutations: instance.append(m) assert instance._entries_processed_since_last_raise == 0 - asyncio.sleep(0.1) - assert instance._entries_processed_since_last_raise == num_nutations + CrossSync._Sync_Impl.sleep(0.1) + assert instance._entries_processed_since_last_raise == num_mutations def test__execute_mutate_rows(self): - if self.is_async(): + if CrossSync._Sync_Impl.is_async: mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" else: - mutate_path = "_sync._mutate_rows._MutateRowsOperation" + mutate_path = "_sync.mutations_batcher._MutateRowsOperation" with mock.patch(f"google.cloud.bigtable.data.{mutate_path}") as mutate_rows: - mutate_rows.return_value = AsyncMock() + mutate_rows.return_value = CrossSync._Sync_Impl.Mock() start_operation = mutate_rows().start table = mock.Mock() table.table_name = "test-table" @@ -867,10 +863,10 @@ def test__execute_mutate_rows_returns_errors(self): FailedMutationEntryError, ) - if self.is_async(): + if CrossSync._Sync_Impl.is_async: mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" else: - mutate_path = "_sync._mutate_rows._MutateRowsOperation" + mutate_path = "_sync.mutations_batcher._MutateRowsOperation" with mock.patch( f"google.cloud.bigtable.data.{mutate_path}.start" ) as mutate_rows: @@ -908,16 +904,16 @@ def test__raise_exceptions(self): instance._oldest_exceptions, instance._newest_exceptions = ([], []) instance._raise_exceptions() - def test___aenter__(self): + def test___enter__(self): """Should return self""" with self._make_one() as instance: - assert instance.__aenter__() == instance + assert instance.__enter__() == instance - def test___aexit__(self): + def test___exit__(self): """aexit should call close""" with self._make_one() as instance: with mock.patch.object(instance, "close") as close_mock: - 
instance.__aexit__(None, None, None) + instance.__exit__(None, None, None) assert close_mock.call_count == 1 def test_close(self): @@ -978,12 +974,13 @@ def test_atexit_registration(self): def test_timeout_args_passed(self): """batch_operation_timeout and batch_attempt_timeout should be used in api calls""" - if self.is_async(): + if CrossSync._Sync_Impl.is_async: mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" else: - mutate_path = "_sync._mutate_rows._MutateRowsOperation" + mutate_path = "_sync.mutations_batcher._MutateRowsOperation" with mock.patch( - f"google.cloud.bigtable.data.{mutate_path}", return_value=AsyncMock() + f"google.cloud.bigtable.data.{mutate_path}", + return_value=CrossSync._Sync_Impl.Mock(), ) as mutate_rows: expected_operation_timeout = 17 expected_attempt_timeout = 13 @@ -1079,18 +1076,15 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): def test_customizable_retryable_errors(self, input_retryables, expected_retryables): """Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer.""" - retryn_fn = ( - "google.cloud.bigtable.data._sync.cross_sync.CrossSync.retry_target" - if "Async" in self._get_target_class().__name__ - else "google.api_core.retry.retry_target" - ) with mock.patch.object( google.api_core.retry, "if_exception_type" ) as predicate_builder_mock: - with mock.patch(retryn_fn) as retry_fn_mock: + with mock.patch.object( + CrossSync._Sync_Impl, "retry_target" + ) as retry_fn_mock: table = None with mock.patch("asyncio.create_task"): - table = TableAsync(mock.Mock(), "instance", "table") + table = Table(mock.Mock(), "instance", "table") with self._make_one( table, batch_retryable_errors=input_retryables ) as instance: diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync/test_read_rows_acceptance.py index aa7725d5e..3553d5fbd 100644 --- a/tests/unit/data/_sync/test_read_rows_acceptance.py +++ b/tests/unit/data/_sync/test_read_rows_acceptance.py @@ -24,23 +24,23 @@ from google.cloud.bigtable.data.row import Row from google.cloud.bigtable.data._sync.cross_sync import CrossSync -if not CrossSync._Sync_Impl.is_async: +if CrossSync._Sync_Impl.is_async: + pass +else: from .._async.test_read_rows_acceptance import ReadRowsTest from .._async.test_read_rows_acceptance import TestFile + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + from google.cloud.bigtable.data._sync.client import BigtableDataClient class TestReadRowsAcceptance: @staticmethod def _get_operation_class(): - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - - return _ReadRowsOperationAsync + return _ReadRowsOperation @staticmethod def _get_client_class(): - from google.cloud.bigtable.data._async.client import BigtableDataClientAsync - - return BigtableDataClientAsync + return BigtableDataClient def parse_readrows_acceptance_tests(): dirname = os.path.dirname(__file__) @@ -134,12 +134,18 @@ def __init__(self, chunk_list): def __aiter__(self): return self + def __iter__(self): + return self + def __anext__(self): self.idx += 1 if len(self.chunk_list) > self.idx: chunk = self.chunk_list[self.idx] return ReadRowsResponse(chunks=[chunk]) - raise StopAsyncIteration + raise CrossSync._Sync_Impl.StopIteration + + def __next__(self): + return self.__anext__() def cancel(self): pass From 5cdf5d91f68a3deb10b3a0fceca100591027ad15 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 4 Jul 2024 01:55:22 
-0600 Subject: [PATCH 132/360] ran black --- .../cloud/bigtable/data/_sync/transformers.py | 8 +++++-- tests/unit/data/_async/test__read_rows.py | 13 ++++++++--- tests/unit/data/_async/test_client.py | 2 +- .../data/_async/test_mutations_batcher.py | 23 +++++++++++++------ .../data/_async/test_read_rows_acceptance.py | 12 +++++++--- tests/unit/data/_sync/test_client.py | 2 +- .../unit/data/_sync/test_mutations_batcher.py | 7 ++++-- 7 files changed, 48 insertions(+), 19 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index fcff6187c..b40c5bcb2 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -19,7 +19,7 @@ class SymbolReplacer(ast.NodeTransformer): - def __init__(self, replacements:dict[str, str]): + def __init__(self, replacements: dict[str, str]): self.replacements = replacements def visit_Name(self, node): @@ -268,7 +268,11 @@ def visit_ClassDef(self, node): return node def _transform_class( - self, cls_ast: ast.ClassDef, new_name: str, replace_symbols:dict[str, str]|None=None, **kwargs + self, + cls_ast: ast.ClassDef, + new_name: str, + replace_symbols: dict[str, str] | None = None, + **kwargs, ) -> ast.ClassDef: """ Transform async class into sync one, by running through a series of transformers diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index c3e201990..405e94e57 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -18,7 +18,9 @@ if CrossSync.is_async: from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync else: - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + from google.cloud.bigtable.data._sync._read_rows import ( # noqa: F401 + _ReadRowsOperation, + ) # try/except added for compatibility with python < 3.8 try: @@ -43,7 +45,9 @@ class TestReadRowsOperationAsync: """ @staticmethod - @CrossSync.convert(replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"}) + @CrossSync.convert( + replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} + ) def _get_target_class(): return _ReadRowsOperationAsync @@ -328,7 +332,10 @@ async def mock_stream(): assert "emit count exceeds row limit" in str(e.value) @CrossSync.pytest - @CrossSync.convert(sync_name="test_close", replace_symbols={"aclose": "close", "__anext__": "__next__"}) + @CrossSync.convert( + sync_name="test_close", + replace_symbols={"aclose": "close", "__anext__": "__next__"}, + ) async def test_aclose(self): """ should be able to close a stream safely with aclose. 
diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 5f68d571c..a2b468cb8 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1380,7 +1380,7 @@ async def test_customizable_retryable_errors( ) as retry_fn_mock: async with self._make_client() as client: table = client.get_table("instance-id", "table-id") - expected_predicate = lambda a: a in expected_retryables # noqa + expected_predicate = expected_retryables.__contains__ retry_fn_mock.side_effect = RuntimeError("stop early") with mock.patch( "google.api_core.retry.if_exception_type" diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 0a0aa8359..765ac4e13 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -29,8 +29,11 @@ MutationsBatcherAsync, ) else: - from google.cloud.bigtable.data._sync.client import Table - from google.cloud.bigtable.data._sync.mutations_batcher import _FlowControl, MutationsBatcher + from google.cloud.bigtable.data._sync.client import Table # noqa: F401 + from google.cloud.bigtable.data._sync.mutations_batcher import ( # noqa: F401 + _FlowControl, + MutationsBatcher, + ) # try/except added for compatibility with python < 3.8 @@ -174,7 +177,10 @@ async def task_routine(): if CrossSync.is_async: # for async class, build task to test flow unlock task = asyncio.create_task(task_routine()) - task_alive = lambda: not task.done() # noqa + + def task_alive(): + return not task.done() + else: # this branch will be tested in sync version of this test import threading @@ -317,7 +323,6 @@ async def test_add_to_flow_oversize(self): "tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" ) class TestMutationsBatcherAsync: - @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) def _get_target_class(self): return MutationsBatcherAsync @@ -1016,14 +1021,18 @@ async def test__raise_exceptions(self): instance._raise_exceptions() @CrossSync.pytest - @CrossSync.convert(sync_name="test___enter__", replace_symbols={"__aenter__": "__enter__"}) + @CrossSync.convert( + sync_name="test___enter__", replace_symbols={"__aenter__": "__enter__"} + ) async def test___aenter__(self): """Should return self""" async with self._make_one() as instance: assert await instance.__aenter__() == instance @CrossSync.pytest - @CrossSync.convert(sync_name="test___exit__", replace_symbols={"__aexit__": "__exit__"}) + @CrossSync.convert( + sync_name="test___exit__", replace_symbols={"__aexit__": "__exit__"} + ) async def test___aexit__(self): """aexit should call close""" async with self._make_one() as instance: @@ -1226,7 +1235,7 @@ async def test_customizable_retryable_errors( table, batch_retryable_errors=input_retryables ) as instance: assert instance._retryable_errors == expected_retryables - expected_predicate = lambda a: a in expected_retryables # noqa + expected_predicate = expected_retryables.__contains__ predicate_builder_mock.return_value = expected_predicate retry_fn_mock.side_effect = RuntimeError("stop early") mutation = self._make_mutation(count=1, size=1) diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 9b00ad236..4a9939abf 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -34,7 +34,9 @@ else: from .._async.test_read_rows_acceptance 
import ReadRowsTest # noqa: F401 from .._async.test_read_rows_acceptance import TestFile # noqa: F401 - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation # noqa: F401 + from google.cloud.bigtable.data._sync._read_rows import ( # noqa: F401 + _ReadRowsOperation, + ) from google.cloud.bigtable.data._sync.client import BigtableDataClient # noqa: F401 @@ -67,12 +69,16 @@ class TestFile(proto.Message): # noqa: F811 ) class TestReadRowsAcceptanceAsync: @staticmethod - @CrossSync.convert(replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"}) + @CrossSync.convert( + replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} + ) def _get_operation_class(): return _ReadRowsOperationAsync @staticmethod - @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}) + @CrossSync.convert( + replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} + ) def _get_client_class(): return BigtableDataClientAsync diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 86f235b79..e25fea3ac 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -1130,7 +1130,7 @@ def test_customizable_retryable_errors( ) as retry_fn_mock: with self._make_client() as client: table = client.get_table("instance-id", "table-id") - expected_predicate = lambda a: a in expected_retryables + expected_predicate = expected_retryables.__contains__ retry_fn_mock.side_effect = RuntimeError("stop early") with mock.patch( "google.api_core.retry.if_exception_type" diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py index e176b54e9..f044e09aa 100644 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -159,7 +159,10 @@ def task_routine(): if CrossSync._Sync_Impl.is_async: task = asyncio.create_task(task_routine()) - task_alive = lambda: not task.done() + + def task_alive(): + return not task.done() + else: import threading @@ -1089,7 +1092,7 @@ def test_customizable_retryable_errors(self, input_retryables, expected_retryabl table, batch_retryable_errors=input_retryables ) as instance: assert instance._retryable_errors == expected_retryables - expected_predicate = lambda a: a in expected_retryables + expected_predicate = expected_retryables.__contains__ predicate_builder_mock.return_value = expected_predicate retry_fn_mock.side_effect = RuntimeError("stop early") mutation = self._make_mutation(count=1, size=1) From b964c8d19e11cb15d1f9043d675e6cf0473d6e86 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 8 Jul 2024 12:37:36 -0600 Subject: [PATCH 133/360] use crossync for system tests --- .../cloud/bigtable/data/_sync/cross_sync.py | 16 + .../cloud/bigtable/data/_sync/transformers.py | 13 +- tests/system/data/test_system.py | 812 ++++++++++++++++++ tests/system/data/test_system_async.py | 129 +-- 4 files changed, 908 insertions(+), 62 deletions(-) create mode 100644 tests/system/data/test_system.py diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index e857543f0..c9e01fdc4 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -49,6 +49,7 @@ class CrossSync: sleep = asyncio.sleep retry_target = retries.retry_target_async retry_target_stream = retries.retry_target_stream_async + Retry = retries.AsyncRetry Queue: TypeAlias = asyncio.Queue Condition: 
TypeAlias = asyncio.Condition Future: TypeAlias = asyncio.Future @@ -106,6 +107,13 @@ def pytest(func): return pytest.mark.asyncio(func) + @staticmethod + def pytest_fixture(*args, **kwargs): + import pytest_asyncio + def decorator(func): + return pytest_asyncio.fixture(*args, **kwargs)(func) + return decorator + @staticmethod async def gather_partials( partial_list: Sequence[Callable[[], Awaitable[T]]], @@ -218,6 +226,7 @@ class _Sync_Impl: sleep = time.sleep retry_target = retries.retry_target retry_target_stream = retries.retry_target_stream + Retry = retries.Retry Queue: TypeAlias = queue.Queue Condition: TypeAlias = threading.Condition Future: TypeAlias = concurrent.futures.Future @@ -273,6 +282,13 @@ def event_wait( ) -> None: event.wait(timeout=timeout) + @staticmethod + def pytest_fixture(*args, **kwargs): + import pytest + def decorator(func): + return pytest.fixture(*args, **kwargs)(func) + return decorator + @staticmethod def gather_partials( partial_list: Sequence[Callable[[], T]], diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index b40c5bcb2..17edfebc6 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -130,6 +130,11 @@ def visit_FunctionDef(self, node): node = SymbolReplacer(replacements).visit(node) elif decorator_type == "pytest": pass + elif decorator_type == "pytest_fixture": + # keep decorator + node.decorator_list.append(decorator) + elif decorator_type == "Retry": + node.decorator_list.append(decorator) elif decorator_type == "drop_method": return None else: @@ -185,12 +190,12 @@ def render(self, with_black=True, save_to_disk=False) -> str: full_str += ( f'\n# mypy: disable-error-code="{",".join(self.mypy_ignore)}"\n\n' ) - full_str += "\n".join([ast.unparse(node) for node in self.imports]) + full_str += "\n".join([ast.unparse(node) for node in self.imports]) # type: ignore full_str += "\n\n" - full_str += "\n".join([ast.unparse(node) for node in self.converted_classes]) + full_str += "\n".join([ast.unparse(node) for node in self.converted_classes]) # type: ignore if with_black: - import black - import autoflake + import black # type: ignore + import autoflake # type: ignore full_str = black.format_str( autoflake.fix_code(full_str, remove_all_unused_imports=True), diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py new file mode 100644 index 000000000..46605cf4e --- /dev/null +++ b/tests/system/data/test_system.py @@ -0,0 +1,812 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by CrossSync. Do not edit manually. 
+import pytest +import uuid +import os +from google.api_core import retry +from google.api_core.exceptions import ClientError +from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE +from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +if CrossSync._Sync_Impl.is_async: + pass +else: + from google.cloud.bigtable.data._sync.client import BigtableDataClient + from .test_system_async import TEST_FAMILY, TEST_FAMILY_2 + + +class TempRowBuilder: + """ + Used to add rows to a table for testing purposes. + """ + + def __init__(self, table): + self.rows = [] + self.table = table + + def add_row( + self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" + ): + if isinstance(value, str): + value = value.encode("utf-8") + elif isinstance(value, int): + value = value.to_bytes(8, byteorder="big", signed=True) + request = { + "table_name": self.table.table_name, + "row_key": row_key, + "mutations": [ + { + "set_cell": { + "family_name": family, + "column_qualifier": qualifier, + "value": value, + } + } + ], + } + self.table.client._gapic_client.mutate_row(request) + self.rows.append(row_key) + + def delete_rows(self): + if self.rows: + request = { + "table_name": self.table.table_name, + "entries": [ + {"row_key": row, "mutations": [{"delete_from_row": {}}]} + for row in self.rows + ], + } + self.table.client._gapic_client.mutate_rows(request) + + +class TestSystem: + @CrossSync._Sync_Impl.pytest_fixture(scope="session") + def client(self): + project = os.getenv("GOOGLE_CLOUD_PROJECT") or None + with BigtableDataClient(project=project, pool_size=4) as client: + yield client + + @CrossSync._Sync_Impl.pytest_fixture(scope="session") + def table(self, client, table_id, instance_id): + with client.get_table(instance_id, table_id) as table: + yield table + + @pytest.fixture(scope="session") + def column_family_config(self): + """specify column families to create when creating a new test table""" + from google.cloud.bigtable_admin_v2 import types + + return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} + + @pytest.fixture(scope="session") + def init_table_id(self): + """The table_id to use when creating a new test table""" + return f"test-table-{uuid.uuid4().hex}" + + @pytest.fixture(scope="session") + def cluster_config(self, project_id): + """Configuration for the clusters to use when creating a new instance""" + from google.cloud.bigtable_admin_v2 import types + + cluster = { + "test-cluster": types.Cluster( + location=f"projects/{project_id}/locations/us-central1-b", serve_nodes=1 + ) + } + return cluster + + @pytest.mark.usefixtures("table") + def _retrieve_cell_value(self, table, row_key): + """Helper to read an individual row""" + from google.cloud.bigtable.data import ReadRowsQuery + + row_list = table.read_rows(ReadRowsQuery(row_keys=row_key)) + assert len(row_list) == 1 + row = row_list[0] + cell = row.cells[0] + return cell.value + + def _create_row_and_mutation( + self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" + ): + """Helper to create a new row, and a sample set_cell mutation to change its value""" + from google.cloud.bigtable.data.mutations import SetCell + + row_key = uuid.uuid4().hex.encode() + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row( + row_key, family=family, qualifier=qualifier, value=start_value + ) + assert self._retrieve_cell_value(table, row_key) == start_value + mutation = 
SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) + return (row_key, mutation) + + @CrossSync._Sync_Impl.pytest_fixture(scope="function") + def temp_rows(self, table): + builder = TempRowBuilder(table) + yield builder + builder.delete_rows() + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=10 + ) + def test_ping_and_warm_gapic(self, client, table): + """Simple ping rpc test + This test ensures channels are able to authenticate with backend""" + request = {"name": table.instance_name} + client._gapic_client.ping_and_warm(request) + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_ping_and_warm(self, client, table): + """Test ping and warm from handwritten client""" + try: + channel = client.transport._grpc_channel.pool[0] + except Exception: + channel = client.transport._grpc_channel + results = client._ping_and_warm_instances(channel) + assert len(results) == 1 + assert results[0] is None + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutation_set_cell(self, table, temp_rows): + """Ensure cells can be set properly""" + row_key = b"bulk_mutate" + new_value = uuid.uuid4().hex.encode() + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + table.mutate_row(row_key, mutation) + assert self._retrieve_cell_value(table, row_key) == new_value + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" + ) + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_sample_row_keys(self, client, table, temp_rows, column_split_config): + """Sample keys should return a single sample in small test tables""" + temp_rows.add_row(b"row_key_1") + temp_rows.add_row(b"row_key_2") + results = table.sample_row_keys() + assert len(results) == len(column_split_config) + 1 + for idx in range(len(column_split_config)): + assert results[idx][0] == column_split_config[idx] + assert isinstance(results[idx][1], int) + assert results[-1][0] == b"" + assert isinstance(results[-1][1], int) + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_bulk_mutations_set_cell(self, client, table, temp_rows): + """Ensure cells can be set properly""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + table.bulk_mutate_rows([bulk_mutation]) + assert self._retrieve_cell_value(table, row_key) == new_value + + def test_bulk_mutations_raise_exception(self, client, table): + """If an invalid mutation is passed, an exception should be raised""" + from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + row_key = uuid.uuid4().hex.encode() + mutation = SetCell( + family="nonexistent", qualifier=b"test-qualifier", new_value=b"" 
+ ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + with pytest.raises(MutationsExceptionGroup) as exc: + table.bulk_mutate_rows([bulk_mutation]) + assert len(exc.value.exceptions) == 1 + entry_error = exc.value.exceptions[0] + assert isinstance(entry_error, FailedMutationEntryError) + assert entry_error.index == 0 + assert entry_error.entry == bulk_mutation + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_context_manager(self, client, table, temp_rows): + """test batcher with context manager. Should flush on exit""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + row_key2, mutation2 = self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + with table.mutations_batcher() as batcher: + batcher.append(bulk_mutation) + batcher.append(bulk_mutation2) + assert self._retrieve_cell_value(table, row_key) == new_value + assert len(batcher._staged_entries) == 0 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_timer_flush(self, client, table, temp_rows): + """batch should occur after flush_interval seconds""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + flush_interval = 0.1 + with table.mutations_batcher(flush_interval=flush_interval) as batcher: + batcher.append(bulk_mutation) + CrossSync._Sync_Impl.yield_to_event_loop() + assert len(batcher._staged_entries) == 1 + CrossSync._Sync_Impl.sleep(flush_interval + 0.1) + assert len(batcher._staged_entries) == 0 + assert self._retrieve_cell_value(table, row_key) == new_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_count_flush(self, client, table, temp_rows): + """batch should flush after flush_limit_mutation_count mutations""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: + batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + assert len(batcher._staged_entries) == 1 + batcher.append(bulk_mutation2) + assert len(batcher._flush_jobs) == 1 + for future in list(batcher._flush_jobs): + future + future.result() + assert len(batcher._staged_entries) == 0 + assert len(batcher._flush_jobs) == 0 + assert self._retrieve_cell_value(table, row_key) 
== new_value + assert self._retrieve_cell_value(table, row_key2) == new_value2 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): + """batch should flush after flush_limit_bytes bytes""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 + with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: + batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + assert len(batcher._staged_entries) == 1 + batcher.append(bulk_mutation2) + assert len(batcher._flush_jobs) == 1 + assert len(batcher._staged_entries) == 0 + for future in list(batcher._flush_jobs): + future + future.result() + assert self._retrieve_cell_value(table, row_key) == new_value + assert self._retrieve_cell_value(table, row_key2) == new_value2 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_mutations_batcher_no_flush(self, client, table, temp_rows): + """test with no flush requirements met""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + start_value = b"unchanged" + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 + with table.mutations_batcher( + flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 + ) as batcher: + batcher.append(bulk_mutation) + assert len(batcher._staged_entries) == 1 + batcher.append(bulk_mutation2) + assert len(batcher._flush_jobs) == 0 + CrossSync._Sync_Impl.yield_to_event_loop() + assert len(batcher._staged_entries) == 2 + assert len(batcher._flush_jobs) == 0 + assert self._retrieve_cell_value(table, row_key) == start_value + assert self._retrieve_cell_value(table, row_key2) == start_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,increment,expected", + [ + (0, 0, 0), + (0, 1, 1), + (0, -1, -1), + (1, 0, 1), + (0, -100, -100), + (0, 3000, 3000), + (10, 4, 14), + (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), + (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), + (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), + ], + ) + def test_read_modify_write_row_increment( + self, client, table, temp_rows, start, increment, expected + ): + """test read_modify_write_row""" + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + rule = IncrementRule(family, qualifier, increment) + 
result = table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert int(result[0]) == expected + assert self._retrieve_cell_value(table, row_key) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,append,expected", + [ + (b"", b"", b""), + ("", "", b""), + (b"abc", b"123", b"abc123"), + (b"abc", "123", b"abc123"), + ("", b"1", b"1"), + (b"abc", "", b"abc"), + (b"hello", b"world", b"helloworld"), + ], + ) + def test_read_modify_write_row_append( + self, client, table, temp_rows, start, append, expected + ): + """test read_modify_write_row""" + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + rule = AppendValueRule(family, qualifier, append) + result = table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert result[0].value == expected + assert self._retrieve_cell_value(table, row_key) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_read_modify_write_row_chained(self, client, table, temp_rows): + """test read_modify_write_row with multiple rules""" + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + start_amount = 1 + increment_amount = 10 + temp_rows.add_row( + row_key, value=start_amount, family=family, qualifier=qualifier + ) + rule = [ + IncrementRule(family, qualifier, increment_amount), + AppendValueRule(family, qualifier, "hello"), + AppendValueRule(family, qualifier, "world"), + AppendValueRule(family, qualifier, "!"), + ] + result = table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert result[0].family == family + assert result[0].qualifier == qualifier + assert ( + result[0].value + == (start_amount + increment_amount).to_bytes(8, "big", signed=True) + + b"helloworld!" 
+        )
+        assert self._retrieve_cell_value(table, row_key) == result[0].value
+
+    @pytest.mark.usefixtures("client")
+    @pytest.mark.usefixtures("table")
+    @pytest.mark.parametrize(
+        "start_val,predicate_range,expected_result",
+        [(1, (0, 2), True), (-1, (0, 2), False)],
+    )
+    def test_check_and_mutate(
+        self, client, table, temp_rows, start_val, predicate_range, expected_result
+    ):
+        """test that check_and_mutate_row applies the right mutations, and returns the right result"""
+        from google.cloud.bigtable.data.mutations import SetCell
+        from google.cloud.bigtable.data.row_filters import ValueRangeFilter
+
+        row_key = b"test-row-key"
+        family = TEST_FAMILY
+        qualifier = b"test-qualifier"
+        temp_rows.add_row(row_key, value=start_val, family=family, qualifier=qualifier)
+        false_mutation_value = b"false-mutation-value"
+        false_mutation = SetCell(
+            family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value
+        )
+        true_mutation_value = b"true-mutation-value"
+        true_mutation = SetCell(
+            family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value
+        )
+        predicate = ValueRangeFilter(predicate_range[0], predicate_range[1])
+        result = table.check_and_mutate_row(
+            row_key,
+            predicate,
+            true_case_mutations=true_mutation,
+            false_case_mutations=false_mutation,
+        )
+        assert result == expected_result
+        expected_value = (
+            true_mutation_value if expected_result else false_mutation_value
+        )
+        assert self._retrieve_cell_value(table, row_key) == expected_value
+
+    @pytest.mark.skipif(
+        bool(os.environ.get(BIGTABLE_EMULATOR)),
+        reason="emulator doesn't raise InvalidArgument",
+    )
+    @pytest.mark.usefixtures("client")
+    @pytest.mark.usefixtures("table")
+    def test_check_and_mutate_empty_request(self, client, table):
+        """check_and_mutate with no true or false mutations should raise an error"""
+        from google.api_core import exceptions
+
+        with pytest.raises(exceptions.InvalidArgument) as e:
+            table.check_and_mutate_row(
+                b"row_key", None, true_case_mutations=None, false_case_mutations=None
+            )
+        assert "No mutations provided" in str(e.value)
+
+    @pytest.mark.usefixtures("table")
+    @CrossSync._Sync_Impl.Retry(
+        predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
+    )
+    def test_read_rows_stream(self, table, temp_rows):
+        """Ensure that the read_rows_stream method works"""
+        temp_rows.add_row(b"row_key_1")
+        temp_rows.add_row(b"row_key_2")
+        generator = table.read_rows_stream({})
+        first_row = generator.__next__()
+        second_row = generator.__next__()
+        assert first_row.row_key == b"row_key_1"
+        assert second_row.row_key == b"row_key_2"
+        with pytest.raises(CrossSync._Sync_Impl.StopIteration):
+            generator.__next__()
+
+    @pytest.mark.usefixtures("table")
+    @CrossSync._Sync_Impl.Retry(
+        predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
+    )
+    def test_read_rows(self, table, temp_rows):
+        """Ensure that the read_rows method works"""
+        temp_rows.add_row(b"row_key_1")
+        temp_rows.add_row(b"row_key_2")
+        row_list = table.read_rows({})
+        assert len(row_list) == 2
+        assert row_list[0].row_key == b"row_key_1"
+        assert row_list[1].row_key == b"row_key_2"
+
+    @pytest.mark.usefixtures("table")
+    @CrossSync._Sync_Impl.Retry(
+        predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
+    )
+    def test_read_rows_sharded_simple(self, table, temp_rows):
+        """Test read rows sharded with two queries"""
+        from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
+
+        temp_rows.add_row(b"a")
+        temp_rows.add_row(b"b")
+        temp_rows.add_row(b"c")
+        
temp_rows.add_row(b"d") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) + row_list = table.read_rows_sharded([query1, query2]) + assert len(row_list) == 4 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"b" + assert row_list[3].row_key == b"d" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_sharded_from_sample(self, table, temp_rows): + """Test end-to-end sharding""" + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.read_rows_query import RowRange + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + table_shard_keys = table.sample_row_keys() + query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) + shard_queries = query.shard(table_shard_keys) + row_list = table.read_rows_sharded(shard_queries) + assert len(row_list) == 3 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"d" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_sharded_filters_limits(self, table, temp_rows): + """Test read rows sharded with filters and limits""" + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + label_filter1 = ApplyLabelFilter("first") + label_filter2 = ApplyLabelFilter("second") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) + row_list = table.read_rows_sharded([query1, query2]) + assert len(row_list) == 3 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"b" + assert row_list[2].row_key == b"d" + assert row_list[0][0].labels == ["first"] + assert row_list[1][0].labels == ["second"] + assert row_list[2][0].labels == ["second"] + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_range_query(self, table, temp_rows): + """Ensure that the read_rows method works""" + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data import RowRange + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) + row_list = table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_single_key_query(self, table, temp_rows): + """Ensure that the read_rows method works with specified query""" + from google.cloud.bigtable.data import ReadRowsQuery + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + query = ReadRowsQuery(row_keys=[b"a", b"c"]) + row_list = table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"a" + assert 
row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_with_filter(self, table, temp_rows): + """ensure filters are applied""" + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + expected_label = "test-label" + row_filter = ApplyLabelFilter(expected_label) + query = ReadRowsQuery(row_filter=row_filter) + row_list = table.read_rows(query) + assert len(row_list) == 4 + for row in row_list: + assert row[0].labels == [expected_label] + + @pytest.mark.usefixtures("table") + def test_read_rows_stream_close(self, table, temp_rows): + """Ensure that the read_rows_stream can be closed""" + from google.cloud.bigtable.data import ReadRowsQuery + + temp_rows.add_row(b"row_key_1") + temp_rows.add_row(b"row_key_2") + query = ReadRowsQuery() + generator = table.read_rows_stream(query) + first_row = generator.__next__() + assert first_row.row_key == b"row_key_1" + generator.close() + with pytest.raises(CrossSync._Sync_Impl.StopIteration): + generator.__next__() + + @pytest.mark.usefixtures("table") + def test_read_row(self, table, temp_rows): + """Test read_row (single row helper)""" + from google.cloud.bigtable.data import Row + + temp_rows.add_row(b"row_key_1", value=b"value") + row = table.read_row(b"row_key_1") + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + def test_read_row_missing(self, table): + """Test read_row when row does not exist""" + from google.api_core import exceptions + + row_key = "row_key_not_exist" + result = table.read_row(row_key) + assert result is None + with pytest.raises(exceptions.InvalidArgument) as e: + table.read_row("") + assert "Row keys must be non-empty" in str(e) + + @pytest.mark.usefixtures("table") + def test_read_row_w_filter(self, table, temp_rows): + """Test read_row (single row helper)""" + from google.cloud.bigtable.data import Row + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + temp_rows.add_row(b"row_key_1", value=b"value") + expected_label = "test-label" + label_filter = ApplyLabelFilter(expected_label) + row = table.read_row(b"row_key_1", row_filter=label_filter) + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + assert row.cells[0].labels == [expected_label] + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + def test_row_exists(self, table, temp_rows): + from google.api_core import exceptions + + "Test row_exists with rows that exist and don't exist" + assert table.row_exists(b"row_key_1") is False + temp_rows.add_row(b"row_key_1") + assert table.row_exists(b"row_key_1") is True + assert table.row_exists("row_key_1") is True + assert table.row_exists(b"row_key_2") is False + assert table.row_exists("row_key_2") is False + assert table.row_exists("3") is False + temp_rows.add_row(b"3") + assert table.row_exists(b"3") is True + with pytest.raises(exceptions.InvalidArgument) as e: + table.row_exists("") + assert "Row keys must be non-empty" 
in str(e) + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @pytest.mark.parametrize( + "cell_value,filter_input,expect_match", + [ + (b"abc", b"abc", True), + (b"abc", "abc", True), + (b".", ".", True), + (".*", ".*", True), + (".*", b".*", True), + ("a", ".*", False), + (b".*", b".*", True), + ("\\a", "\\a", True), + (b"\xe2\x98\x83", "☃", True), + ("☃", "☃", True), + ("\\C☃", "\\C☃", True), + (1, 1, True), + (2, 1, False), + (68, 68, True), + ("D", 68, False), + (68, "D", False), + (-1, -1, True), + (2852126720, 2852126720, True), + (-1431655766, -1431655766, True), + (-1431655766, -1, False), + ], + ) + def test_literal_value_filter( + self, table, temp_rows, cell_value, filter_input, expect_match + ): + """Literal value filter does complex escaping on re2 strings. + Make sure inputs are properly interpreted by the server""" + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + from google.cloud.bigtable.data import ReadRowsQuery + + f = LiteralValueFilter(filter_input) + temp_rows.add_row(b"row_key_1", value=cell_value) + query = ReadRowsQuery(row_filter=f) + row_list = table.read_rows(query) + assert len(row_list) == bool( + expect_match + ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 69a23412e..2f259d6eb 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -23,10 +23,18 @@ from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +if CrossSync.is_async: + from google.cloud.bigtable.data._async.client import BigtableDataClientAsync +else: + from google.cloud.bigtable.data._sync.client import BigtableDataClient + from .test_system_async import TEST_FAMILY, TEST_FAMILY_2 + TEST_FAMILY = "test-family" TEST_FAMILY_2 = "test-family-2" - +@CrossSync.sync_output("tests.system.data.test_system.TempRowBuilder") class TempRowBuilderAsync: """ Used to add rows to a table for testing purposes. 
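The hunk above is the heart of the cross-sync test setup: one source module is written against the async surface, the CrossSync.is_async flag picks the matching client import, and @CrossSync.sync_output names the module and class the generator should emit. A minimal sketch of that markup, using a hypothetical DemoAsync class and output path rather than anything from this repo:

from google.cloud.bigtable.data._sync.cross_sync import CrossSync

if CrossSync.is_async:
    from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
else:
    from google.cloud.bigtable.data._sync.client import BigtableDataClient

@CrossSync.sync_output("tests.system.data.demo_sync.Demo")  # hypothetical output path
class DemoAsync:
    # written once against the async API; re-emitted as Demo in the sync module
    @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"})
    @CrossSync.pytest_fixture(scope="session")
    async def client(self):
        async with BigtableDataClientAsync(project=None) as client:
            yield client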
@@ -71,16 +79,17 @@ async def delete_rows(self): await self.table.client._gapic_client.mutate_rows(request) +@CrossSync.sync_output("tests.system.data.test_system.TestSystem") class TestSystemAsync: - @pytest_asyncio.fixture(scope="session") - async def client(self): - from google.cloud.bigtable.data import BigtableDataClientAsync + @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}) + @CrossSync.pytest_fixture(scope="session") + async def client(self): project = os.getenv("GOOGLE_CLOUD_PROJECT") or None async with BigtableDataClientAsync(project=project, pool_size=4) as client: yield client - @pytest_asyncio.fixture(scope="session") + @CrossSync.pytest_fixture(scope="session") async def table(self, client, table_id, instance_id): async with client.get_table( instance_id, @@ -88,6 +97,7 @@ async def table(self, client, table_id, instance_id): ) as table: yield table + @CrossSync.drop_method @pytest.fixture(scope="session") def event_loop(self): loop = asyncio.get_event_loop() @@ -159,7 +169,8 @@ async def _create_row_and_mutation( mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) return row_key, mutation - @pytest_asyncio.fixture(scope="function") + @CrossSync.convert(replace_symbols={"TempRowBuilderAsync": "TempRowBuilder"}) + @CrossSync.pytest_fixture(scope="function") async def temp_rows(self, table): builder = TempRowBuilderAsync(table) yield builder @@ -167,10 +178,10 @@ async def temp_rows(self, table): @pytest.mark.usefixtures("table") @pytest.mark.usefixtures("client") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=10 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_ping_and_warm_gapic(self, client, table): """ Simple ping rpc test @@ -181,10 +192,10 @@ async def test_ping_and_warm_gapic(self, client, table): @pytest.mark.usefixtures("table") @pytest.mark.usefixtures("client") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_ping_and_warm(self, client, table): """ Test ping and warm from handwritten client @@ -198,9 +209,9 @@ async def test_ping_and_warm(self, client, table): assert len(results) == 1 assert results[0] is None - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) async def test_mutation_set_cell(self, table, temp_rows): @@ -222,10 +233,10 @@ async def test_mutation_set_cell(self, table, temp_rows): ) @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys(self, client, table, temp_rows, column_split_config): """ Sample keys should return a single sample in small test tables @@ -245,7 +256,7 @@ async def test_sample_row_keys(self, client, table, temp_rows, column_split_conf @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("table") - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutations_set_cell(self, client, table, temp_rows): """ Ensure cells can be set properly @@ -263,7 +274,7 @@ async def test_bulk_mutations_set_cell(self, client, table, temp_rows): # ensure cell is updated assert (await self._retrieve_cell_value(table, row_key)) == new_value - 
@pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutations_raise_exception(self, client, table): """ If an invalid mutation is passed, an exception should be raised @@ -288,10 +299,10 @@ async def test_bulk_mutations_raise_exception(self, client, table): @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutations_batcher_context_manager(self, client, table, temp_rows): """ test batcher with context manager. Should flush on exit @@ -317,10 +328,10 @@ async def test_mutations_batcher_context_manager(self, client, table, temp_rows) @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutations_batcher_timer_flush(self, client, table, temp_rows): """ batch should occur after flush_interval seconds @@ -335,19 +346,19 @@ async def test_mutations_batcher_timer_flush(self, client, table, temp_rows): flush_interval = 0.1 async with table.mutations_batcher(flush_interval=flush_interval) as batcher: await batcher.append(bulk_mutation) - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert len(batcher._staged_entries) == 1 - await asyncio.sleep(flush_interval + 0.1) + await CrossSync.sleep(flush_interval + 0.1) assert len(batcher._staged_entries) == 0 # ensure cell is updated assert (await self._retrieve_cell_value(table, row_key)) == new_value @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutations_batcher_count_flush(self, client, table, temp_rows): """ batch should flush after flush_limit_mutation_count mutations @@ -385,10 +396,10 @@ async def test_mutations_batcher_count_flush(self, client, table, temp_rows): @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): """ batch should flush after flush_limit_bytes bytes @@ -426,7 +437,7 @@ async def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("table") - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutations_batcher_no_flush(self, client, table, temp_rows): """ test with no flush requirements met @@ -453,7 +464,7 @@ async def test_mutations_batcher_no_flush(self, client, table, temp_rows): await batcher.append(bulk_mutation2) # flush not scheduled assert len(batcher._flush_jobs) == 0 - await asyncio.sleep(0.01) + await CrossSync.yield_to_event_loop() assert len(batcher._staged_entries) == 2 assert len(batcher._flush_jobs) == 0 # ensure cells were not updated @@ -477,7 +488,7 @@ async def test_mutations_batcher_no_flush(self, client, table, temp_rows): (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_row_increment( self, client, table, temp_rows, start, increment, expected ): @@ -517,7 +528,7 @@ async def test_read_modify_write_row_increment( 
(b"hello", b"world", b"helloworld"), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_row_append( self, client, table, temp_rows, start, append, expected ): @@ -545,7 +556,7 @@ async def test_read_modify_write_row_append( @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("table") - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_row_chained(self, client, table, temp_rows): """ test read_modify_write_row with multiple rules @@ -589,7 +600,7 @@ async def test_read_modify_write_row_chained(self, client, table, temp_rows): (-1, (0, 2), False), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate( self, client, table, temp_rows, start_val, predicate_range, expected_result ): @@ -635,7 +646,7 @@ async def test_check_and_mutate( ) @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("table") - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate_empty_request(self, client, table): """ check_and_mutate with no true or fale mutations should raise an error @@ -649,10 +660,11 @@ async def test_check_and_mutate_empty_request(self, client, table): assert "No mutations provided" in str(e.value) @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_stream(self, table, temp_rows): """ Ensure that the read_rows_stream method works @@ -666,14 +678,14 @@ async def test_read_rows_stream(self, table, temp_rows): second_row = await generator.__anext__() assert first_row.row_key == b"row_key_1" assert second_row.row_key == b"row_key_2" - with pytest.raises(StopAsyncIteration): + with pytest.raises(CrossSync.StopIteration): await generator.__anext__() @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows(self, table, temp_rows): """ Ensure that the read_rows method works @@ -687,10 +699,10 @@ async def test_read_rows(self, table, temp_rows): assert row_list[1].row_key == b"row_key_2" @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_simple(self, table, temp_rows): """ Test read rows sharded with two queries @@ -711,10 +723,10 @@ async def test_read_rows_sharded_simple(self, table, temp_rows): assert row_list[3].row_key == b"d" @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_from_sample(self, table, temp_rows): """ Test end-to-end sharding @@ -737,10 +749,10 @@ async def test_read_rows_sharded_from_sample(self, table, temp_rows): assert row_list[2].row_key == b"d" @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_filters_limits(self, table, temp_rows): """ Test read rows sharded with filters and limits @@ -767,10 +779,10 @@ async def test_read_rows_sharded_filters_limits(self, table, temp_rows): assert 
row_list[2][0].labels == ["second"] @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_range_query(self, table, temp_rows): """ Ensure that the read_rows method works @@ -790,10 +802,10 @@ async def test_read_rows_range_query(self, table, temp_rows): assert row_list[1].row_key == b"c" @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_single_key_query(self, table, temp_rows): """ Ensure that the read_rows method works with specified query @@ -812,10 +824,10 @@ async def test_read_rows_single_key_query(self, table, temp_rows): assert row_list[1].row_key == b"c" @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_with_filter(self, table, temp_rows): """ ensure filters are applied @@ -837,7 +849,8 @@ async def test_read_rows_with_filter(self, table, temp_rows): assert row[0].labels == [expected_label] @pytest.mark.usefixtures("table") - @pytest.mark.asyncio + @CrossSync.convert(replace_symbols={"__anext__": "__next__", "aclose": "close"}) + @CrossSync.pytest async def test_read_rows_stream_close(self, table, temp_rows): """ Ensure that the read_rows_stream can be closed @@ -854,11 +867,11 @@ async def test_read_rows_stream_close(self, table, temp_rows): assert first_row.row_key == b"row_key_1" # close stream early await generator.aclose() - with pytest.raises(StopAsyncIteration): + with pytest.raises(CrossSync.StopIteration): await generator.__anext__() @pytest.mark.usefixtures("table") - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_row(self, table, temp_rows): """ Test read_row (single row helper) @@ -876,7 +889,7 @@ async def test_read_row(self, table, temp_rows): reason="emulator doesn't raise InvalidArgument", ) @pytest.mark.usefixtures("table") - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_row_missing(self, table): """ Test read_row when row does not exist @@ -891,7 +904,7 @@ async def test_read_row_missing(self, table): assert "Row keys must be non-empty" in str(e) @pytest.mark.usefixtures("table") - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_row_w_filter(self, table, temp_rows): """ Test read_row (single row helper) @@ -913,7 +926,7 @@ async def test_read_row_w_filter(self, table, temp_rows): reason="emulator doesn't raise InvalidArgument", ) @pytest.mark.usefixtures("table") - @pytest.mark.asyncio + @CrossSync.pytest async def test_row_exists(self, table, temp_rows): from google.api_core import exceptions @@ -932,7 +945,7 @@ async def test_row_exists(self, table, temp_rows): assert "Row keys must be non-empty" in str(e) @pytest.mark.usefixtures("table") - @retry.AsyncRetry( + @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @pytest.mark.parametrize( @@ -960,7 +973,7 @@ async def test_row_exists(self, table, temp_rows): (-1431655766, -1, False), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_literal_value_filter( self, table, temp_rows, cell_value, filter_input, expect_match ): From a9689057e086225d9aaef4ed522836a2a951d69e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 8 Jul 2024 12:50:24 
-0600 Subject: [PATCH 134/360] don't manually clear channel refresh list --- google/cloud/bigtable/data/_async/client.py | 3 +-- google/cloud/bigtable/data/_sync/client.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 307b41775..8ee92f2cf 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -263,7 +263,7 @@ def _start_background_channel_refresh(self) -> None: ) self._channel_refresh_tasks.append(refresh_task) refresh_task.add_done_callback( - lambda _: self._channel_refresh_tasks.remove(refresh_task) + lambda _: self._channel_refresh_tasks.remove(refresh_task) if refresh_task in self._channel_refresh_tasks else None ) async def close(self, timeout: float | None = None): @@ -277,7 +277,6 @@ async def close(self, timeout: float | None = None): if self._executor: self._executor.shutdown(wait=False) await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) - self._channel_refresh_tasks = [] async def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 9e75637d3..833573365 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -194,6 +194,8 @@ def _start_background_channel_refresh(self) -> None: self._channel_refresh_tasks.append(refresh_task) refresh_task.add_done_callback( lambda _: self._channel_refresh_tasks.remove(refresh_task) + if refresh_task in self._channel_refresh_tasks + else None ) def close(self, timeout: float | None = None): @@ -205,7 +207,6 @@ def close(self, timeout: float | None = None): if self._executor: self._executor.shutdown(wait=False) CrossSync._Sync_Impl.wait(self._channel_refresh_tasks, timeout=timeout) - self._channel_refresh_tasks = [] def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None From 95c30f88f51e0c1617e4e68275384c7e264222dd Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 8 Jul 2024 13:48:13 -0600 Subject: [PATCH 135/360] mark each method to convert --- .../bigtable/data/_async/_mutate_rows.py | 2 + .../cloud/bigtable/data/_async/_read_rows.py | 1 + google/cloud/bigtable/data/_async/client.py | 12 ++++ .../bigtable/data/_async/mutations_batcher.py | 7 ++ .../cloud/bigtable/data/_sync/_mutate_rows.py | 16 +++-- .../cloud/bigtable/data/_sync/_read_rows.py | 10 +-- google/cloud/bigtable/data/_sync/client.py | 65 +++++++++++++------ .../cloud/bigtable/data/_sync/cross_sync.py | 6 +- .../bigtable/data/_sync/mutations_batcher.py | 43 ++++++++---- .../cloud/bigtable/data/_sync/transformers.py | 9 ++- 10 files changed, 123 insertions(+), 48 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 4e4e6b491..fee05fc39 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -141,6 +141,7 @@ def __init__( self.remaining_indices = list(range(len(self.mutations))) self.errors: dict[int, list[Exception]] = {} + @CrossSync.convert async def start(self): """ Start the operation, and run until completion @@ -173,6 +174,7 @@ async def start(self): if all_errors: raise MutationsExceptionGroup(all_errors, len(self.mutations)) + @CrossSync.convert async def _run_attempt(self): """ 
Run a single attempt of the mutate_rows rpc. diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 15dd2b7bf..b375e9d03 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -166,6 +166,7 @@ def _read_rows_attempt(self) -> CrossSync.Iterable[Row]: chunked_stream = self.chunk_stream(gapic_stream) return self.merge_rows(chunked_stream) + @CrossSync.convert async def chunk_stream( self, stream: CrossSync.Awaitable[CrossSync.Iterable[ReadRowsResponsePB]] ) -> CrossSync.Iterable[ReadRowsResponsePB.CellChunk]: diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 8ee92f2cf..237d66fbf 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -266,6 +266,7 @@ def _start_background_channel_refresh(self) -> None: lambda _: self._channel_refresh_tasks.remove(refresh_task) if refresh_task in self._channel_refresh_tasks else None ) + @CrossSync.convert async def close(self, timeout: float | None = None): """ Cancel all background tasks @@ -278,6 +279,7 @@ async def close(self, timeout: float | None = None): self._executor.shutdown(wait=False) await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) + @CrossSync.convert async def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None ) -> list[BaseException | None]: @@ -319,6 +321,7 @@ async def _ping_and_warm_instances( ) return [r or None for r in result_list] + @CrossSync.convert async def _manage_channel( self, channel_idx: int, @@ -678,6 +681,7 @@ async def read_rows_stream( ) return row_merger.start_operation() + @CrossSync.convert async def read_rows( self, query: ReadRowsQuery, @@ -725,6 +729,7 @@ async def read_rows( ) return [row async for row in row_generator] + @CrossSync.convert async def read_row( self, row_key: str | bytes, @@ -774,6 +779,7 @@ async def read_row( return None return results[0] + @CrossSync.convert async def read_rows_sharded( self, sharded_query: ShardedQuery, @@ -873,6 +879,7 @@ async def read_rows_with_semaphore(query): ) return results_list + @CrossSync.convert async def row_exists( self, row_key: str | bytes, @@ -921,6 +928,7 @@ async def row_exists( ) return len(results) > 0 + @CrossSync.convert async def sample_row_keys( self, *, @@ -1043,6 +1051,7 @@ def mutations_batcher( batch_retryable_errors=batch_retryable_errors, ) + @CrossSync.convert async def mutate_row( self, row_key: str | bytes, @@ -1179,6 +1188,7 @@ async def bulk_mutate_rows( ) await operation.start() + @CrossSync.convert async def check_and_mutate_row( self, row_key: str | bytes, @@ -1245,6 +1255,7 @@ async def check_and_mutate_row( ) return result.predicate_matched + @CrossSync.convert async def read_modify_write_row( self, row_key: str | bytes, @@ -1295,6 +1306,7 @@ async def read_modify_write_row( # construct Row from result return Row._from_pb(result.row) + @CrossSync.convert async def close(self): """ Called to close the Table instance and release any resources held by it. 
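The pattern in this patch is worth spelling out: @CrossSync.convert is a pass-through at runtime (the cross_sync.py hunk below simply returns the wrapped function unchanged), and serves purely as an AST marker telling the generator which coroutines to rewrite. A rough before-and-after sketch with a hypothetical method, not one of the real client methods:

from google.cloud.bigtable.data._sync.cross_sync import CrossSync

class DemoTableAsync:
    @CrossSync.convert
    async def first_row(self, query):
        rows = await self.read_rows(query)  # awaited call in the async source
        return rows[0] if rows else None

# generated sync counterpart, written into the _sync module:
#
# class DemoTable:
#     def first_row(self, query):
#         rows = self.read_rows(query)
#         return rows[0] if rows else None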
diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index b7e55e9e1..4a83683dd 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -110,6 +110,7 @@ def _has_capacity(self, additional_count: int, additional_size: int) -> bool: new_count = self._in_flight_mutation_count + additional_count return new_size <= acceptable_size and new_count <= acceptable_count + @CrossSync.convert async def remove_from_flow( self, mutations: RowMutationEntry | list[RowMutationEntry] ) -> None: @@ -131,6 +132,7 @@ async def remove_from_flow( async with self._capacity_condition: self._capacity_condition.notify_all() + @CrossSync.convert async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): """ Generator function that registers mutations with flow control. As mutations @@ -273,6 +275,7 @@ def __init__( # clean up on program exit atexit.register(self._on_exit) + @CrossSync.convert async def _timer_routine(self, interval: float | None) -> None: """ Set up a background task to flush the batcher every interval seconds @@ -293,6 +296,7 @@ async def _timer_routine(self, interval: float | None) -> None: if not self._closed.is_set() and self._staged_entries: self._schedule_flush() + @CrossSync.convert async def append(self, mutation_entry: RowMutationEntry): """ Add a new set of mutations to the internal queue @@ -342,6 +346,7 @@ def _schedule_flush(self) -> CrossSync.Future[None] | None: return new_task return None + @CrossSync.convert async def _flush_internal(self, new_entries: list[RowMutationEntry]): """ Flushes a set of mutations to the server, and updates internal state @@ -467,6 +472,7 @@ def closed(self) -> bool: """ return self._closed.is_set() + @CrossSync.convert async def close(self): """ Flush queue and clean up resources @@ -494,6 +500,7 @@ def _on_exit(self): ) @staticmethod + @CrossSync.convert async def _wait_for_batch_results( *tasks: CrossSync.Future[list[FailedMutationEntryError]] | CrossSync.Future[None], diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index 11591c007..060e5184f 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -28,6 +28,8 @@ from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if not CrossSync._Sync_Impl.is_async: + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -36,8 +38,6 @@ else: from google.cloud.bigtable.data._sync.client import Table from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient -if not CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto class _MutateRowsOperation: @@ -100,10 +100,12 @@ def __init__( self.errors: dict[int, list[Exception]] = {} def start(self): - """Start the operation, and run until completion + """ + Start the operation, and run until completion Raises: - MutationsExceptionGroup: if any mutations failed""" + MutationsExceptionGroup: if any mutations failed + """ try: self._operation() except Exception as exc: @@ -127,12 +129,14 @@ def start(self): raise MutationsExceptionGroup(all_errors, len(self.mutations)) def _run_attempt(self): - """Run a single 
attempt of the mutate_rows rpc. + """ + Run a single attempt of the mutate_rows rpc. Raises: _MutateRowsIncomplete: if there are failed mutations eligible for retry after the attempt is complete - GoogleAPICallError: if the gapic rpc fails""" + GoogleAPICallError: if the gapic rpc fails + """ request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] active_request_indices = { req_idx: orig_idx for req_idx, orig_idx in enumerate(self.remaining_indices) diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index 7a6d1300c..d74c1a43f 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -28,13 +28,13 @@ from google.api_core.retry import exponential_sleep_generator from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if not CrossSync._Sync_Impl.is_async: + from google.cloud.bigtable.data._async._read_rows import _ResetRow if TYPE_CHECKING: if CrossSync._Sync_Impl.is_async: pass else: from google.cloud.bigtable.data._sync.client import Table -if not CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable.data._async._read_rows import _ResetRow class _ReadRowsOperation: @@ -142,12 +142,14 @@ def chunk_stream( CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB] ], ) -> CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB.CellChunk]: - """process chunks out of raw read_rows stream + """ + process chunks out of raw read_rows stream Args: stream: the raw read_rows stream from the gapic client Yields: - ReadRowsResponsePB.CellChunk: the next chunk in the stream""" + ReadRowsResponsePB.CellChunk: the next chunk in the stream + """ for resp in stream: resp = resp._pb if resp.last_scanned_row_key: diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 833573365..ef6532489 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -199,7 +199,9 @@ def _start_background_channel_refresh(self) -> None: ) def close(self, timeout: float | None = None): - """Cancel all background tasks""" + """ + Cancel all background tasks + """ self._is_closed.set() for task in self._channel_refresh_tasks: task.cancel() @@ -211,7 +213,8 @@ def close(self, timeout: float | None = None): def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None ) -> list[BaseException | None]: - """Prepares the backend for requests on a channel + """ + Prepares the backend for requests on a channel Pings each Bigtable instance registered in `_active_instances` on the client @@ -254,7 +257,8 @@ def _manage_channel( refresh_interval_max: float = 60 * 45, grace_period: float = 60 * 10, ) -> None: - """Background coroutine that periodically refreshes and warms a grpc channel + """ + Background coroutine that periodically refreshes and warms a grpc channel The backend will automatically close channels after 60 minutes, so `refresh_interval` + `grace_period` should be < 60 minutes @@ -270,7 +274,8 @@ def _manage_channel( process in seconds. 
Actual interval will be a random value between `refresh_interval_min` and `refresh_interval_max` grace_period: time to allow previous channel to serve existing - requests before closing, in seconds""" + requests before closing, in seconds + """ first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) @@ -574,7 +579,8 @@ def read_rows( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: - """Read a set of rows from the table, based on the specified query. + """ + Read a set of rows from the table, based on the specified query. Retruns results as a list of Row objects when the request is complete. For streamed results, use read_rows_stream. @@ -621,7 +627,8 @@ def read_row( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> Row | None: - """Read a single row from the table, based on the specified key. + """ + Read a single row from the table, based on the specified key. Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. @@ -668,7 +675,8 @@ def read_rows_sharded( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: - """Runs a sharded query in parallel, then return the results in a single list. + """ + Runs a sharded query in parallel, then return the results in a single list. Results will be returned in the order of the input queries. This function is intended to be run on the results on a query.shard() call. @@ -695,7 +703,8 @@ def read_rows_sharded( list[Row]: a list of Rows returned by the query Raises: ShardedReadRowsExceptionGroup: if any of the queries failed - ValueError: if the query_list is empty""" + ValueError: if the query_list is empty + """ if not sharded_query: raise ValueError("empty sharded_query") operation_timeout, attempt_timeout = _helpers._get_timeouts( @@ -757,7 +766,8 @@ def row_exists( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> bool: - """Return a boolean indicating whether the specified row exists in the table. + """ + Return a boolean indicating whether the specified row exists in the table. uses the filters: chain(limit cells per row = 1, strip value) Args: @@ -802,7 +812,8 @@ def sample_row_keys( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> RowKeySamples: - """Return a set of RowKeySamples that delimit contiguous sections of the table of + """ + Return a set of RowKeySamples that delimit contiguous sections of the table of approximately equal size RowKeySamples output can be used with ReadRowsQuery.shard() to create a sharded query that @@ -918,7 +929,8 @@ def mutate_row( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ): - """Mutates a row atomically. + """ + Mutates a row atomically. Cells already present in the row are left unchanged unless explicitly changed by ``mutation``. @@ -946,7 +958,8 @@ def mutate_row( GoogleAPIError exceptions from any retries that failed google.api_core.exceptions.GoogleAPIError: raised on non-idempotent operations that cannot be safely retried. 
- ValueError: if invalid arguments are provided""" + ValueError: if invalid arguments are provided + """ operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) @@ -1039,7 +1052,8 @@ def check_and_mutate_row( false_case_mutations: Mutation | list[Mutation] | None = None, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> bool: - """Mutates a row atomically based on the output of a predicate filter + """ + Mutates a row atomically based on the output of a predicate filter Non-idempotent operation: will not be retried @@ -1068,7 +1082,8 @@ def check_and_mutate_row( Returns: bool indicating whether the predicate was true or false Raises: - google.api_core.exceptions.GoogleAPIError: exceptions from grpc call""" + google.api_core.exceptions.GoogleAPIError: exceptions from grpc call + """ operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) if true_case_mutations is not None and ( not isinstance(true_case_mutations, list) @@ -1101,7 +1116,8 @@ def read_modify_write_row( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> Row: - """Reads and modifies a row atomically according to input ReadModifyWriteRules, + """ + Reads and modifies a row atomically according to input ReadModifyWriteRules, and returns the contents of all modified cells The new value for the timestamp is the greater of the existing timestamp or @@ -1121,7 +1137,8 @@ def read_modify_write_row( Row: a Row containing cell data that was modified as part of the operation Raises: google.api_core.exceptions.GoogleAPIError: exceptions from grpc call - ValueError: if invalid arguments are provided""" + ValueError: if invalid arguments are provided + """ operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) if operation_timeout <= 0: raise ValueError("operation_timeout must be greater than 0") @@ -1142,23 +1159,29 @@ def read_modify_write_row( return Row._from_pb(result.row) def close(self): - """Called to close the Table instance and release any resources held by it.""" + """ + Called to close the Table instance and release any resources held by it. + """ if self._register_instance_future: self._register_instance_future.cancel() self.client._remove_instance_registration(self.instance_id, self) def __enter__(self): - """Implement async context manager protocol + """ + Implement async context manager protocol Ensure registration task has time to run, so that - grpc channels will be warmed for the specified instance""" + grpc channels will be warmed for the specified instance + """ if self._register_instance_future: self._register_instance_future return self def __exit__(self, exc_type, exc_val, exc_tb): - """Implement async context manager protocol + """ + Implement async context manager protocol Unregister this instance with the client, so that - grpc channels will no longer be warmed""" + grpc channels will no longer be warmed + """ self.close() diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index c9e01fdc4..da5e6450a 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -67,8 +67,12 @@ class CrossSync: @staticmethod def convert( - *, sync_name: str | None = None, replace_symbols: dict[str, str] | None = None + *args, sync_name: str | None = None, replace_symbols: dict[str, str] | None = None ): + if args: + # only positional argument should be the function to wrap. 
Return it directly + return args[0] + def decorator(func): return func diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 006982c1f..46274980b 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -97,12 +97,14 @@ def _has_capacity(self, additional_count: int, additional_size: int) -> bool: def remove_from_flow( self, mutations: RowMutationEntry | list[RowMutationEntry] ) -> None: - """Removes mutations from flow control. This method should be called once + """ + Removes mutations from flow control. This method should be called once for each mutation that was sent to add_to_flow, after the corresponding operation is complete. Args: - mutations: mutation or list of mutations to remove from flow control""" + mutations: mutation or list of mutations to remove from flow control + """ if not isinstance(mutations, list): mutations = [mutations] total_count = sum((len(entry.mutations) for entry in mutations)) @@ -113,7 +115,8 @@ def remove_from_flow( self._capacity_condition.notify_all() def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): - """Generator function that registers mutations with flow control. As mutations + """ + Generator function that registers mutations with flow control. As mutations are accepted into the flow control, they are yielded back to the caller, to be sent in a batch. If the flow control is at capacity, the generator will block until there is capacity available. @@ -123,7 +126,8 @@ def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): Yields: list[RowMutationEntry]: list of mutations that have reserved space in the flow control. - Each batch contains at least one mutation.""" + Each batch contains at least one mutation. + """ if not isinstance(mutations, list): mutations = [mutations] start_idx = 0 @@ -237,13 +241,15 @@ def __init__( atexit.register(self._on_exit) def _timer_routine(self, interval: float | None) -> None: - """Set up a background task to flush the batcher every interval seconds + """ + Set up a background task to flush the batcher every interval seconds If interval is None, an empty future is returned Args: flush_interval: Automatically flush every flush_interval seconds. - If None, no time-based flushing is performed.""" + If None, no time-based flushing is performed. 
+ """ if not interval or interval <= 0: return None while not self._closed.is_set(): @@ -254,13 +260,15 @@ def _timer_routine(self, interval: float | None) -> None: self._schedule_flush() def append(self, mutation_entry: RowMutationEntry): - """Add a new set of mutations to the internal queue + """ + Add a new set of mutations to the internal queue Args: mutation_entry: new entry to add to flush queue Raises: RuntimeError: if batcher is closed - ValueError: if an invalid mutation type is added""" + ValueError: if an invalid mutation type is added + """ if self._closed.is_set(): raise RuntimeError("Cannot append to closed MutationsBatcher") if isinstance(mutation_entry, Mutation): @@ -296,10 +304,12 @@ def _schedule_flush(self) -> CrossSync._Sync_Impl.Future[None] | None: return None def _flush_internal(self, new_entries: list[RowMutationEntry]): - """Flushes a set of mutations to the server, and updates internal state + """ + Flushes a set of mutations to the server, and updates internal state Args: - new_entries list of RowMutationEntry objects to flush""" + new_entries list of RowMutationEntry objects to flush + """ in_process_requests: list[ CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] ] = [] @@ -387,9 +397,11 @@ def __enter__(self): return self def __exit__(self, exc_type, exc, tb): - """Allow use of context manager API. + """ + Allow use of context manager API. - Flushes the batcher and cleans up resources.""" + Flushes the batcher and cleans up resources. + """ self.close() @property @@ -399,7 +411,9 @@ def closed(self) -> bool: return self._closed.is_set() def close(self): - """Flush queue and clean up resources""" + """ + Flush queue and clean up resources + """ self._closed.set() self._flush_timer.cancel() self._schedule_flush() @@ -422,7 +436,8 @@ def _wait_for_batch_results( *tasks: CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] | CrossSync._Sync_Impl.Future[None], ) -> list[Exception]: - """Takes in a list of futures representing _execute_mutate_rows tasks, + """ + Takes in a list of futures representing _execute_mutate_rows tasks, waits for them to complete, and returns a list of errors encountered. 
Args: diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index 17edfebc6..6136f4009 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -106,7 +106,11 @@ def visit_ListComp(self, node): class HandleCrossSyncDecorators(ast.NodeTransformer): + def visit_FunctionDef(self, node): + return self.visit_AsyncFunctionDef(node) + + def visit_AsyncFunctionDef(self, node): if hasattr(node, "decorator_list"): found_list, node.decorator_list = node.decorator_list, [] for decorator in found_list: @@ -117,7 +121,9 @@ def visit_FunctionDef(self, node): else decorator.attr ) if decorator_type == "convert": - for subcommand in decorator.keywords: + # convert async to sync + node = AsyncToSync().visit(node) + for subcommand in getattr(decorator, "keywords", []): if subcommand.arg == "sync_name": node.name = subcommand.value.s if subcommand.arg == "replace_symbols": @@ -290,7 +296,6 @@ def _transform_class( d for d in cls_ast.decorator_list if "CrossSync" not in ast.dump(d) ] # convert class contents - cls_ast = AsyncToSync().visit(cls_ast) cls_ast = self.cross_sync_converter.visit(cls_ast) if replace_symbols: cls_ast = SymbolReplacer(replace_symbols).visit(cls_ast) From 8c88f08ba33049992fc45d7c11eaedb2e7c98b4b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 8 Jul 2024 14:17:58 -0600 Subject: [PATCH 136/360] cleaning; adding docstrings --- .../cloud/bigtable/data/_sync/_mutate_rows.py | 12 ++-- .../cloud/bigtable/data/_sync/_read_rows.py | 6 +- google/cloud/bigtable/data/_sync/client.py | 65 ++++++------------ .../cloud/bigtable/data/_sync/cross_sync.py | 2 +- .../bigtable/data/_sync/mutations_batcher.py | 43 ++++-------- .../cloud/bigtable/data/_sync/transformers.py | 68 ++++++++++++++----- 6 files changed, 92 insertions(+), 104 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index 060e5184f..d65cf3c61 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -100,12 +100,10 @@ def __init__( self.errors: dict[int, list[Exception]] = {} def start(self): - """ - Start the operation, and run until completion + """Start the operation, and run until completion Raises: - MutationsExceptionGroup: if any mutations failed - """ + MutationsExceptionGroup: if any mutations failed""" try: self._operation() except Exception as exc: @@ -129,14 +127,12 @@ def start(self): raise MutationsExceptionGroup(all_errors, len(self.mutations)) def _run_attempt(self): - """ - Run a single attempt of the mutate_rows rpc. + """Run a single attempt of the mutate_rows rpc. 
Raises: _MutateRowsIncomplete: if there are failed mutations eligible for retry after the attempt is complete - GoogleAPICallError: if the gapic rpc fails - """ + GoogleAPICallError: if the gapic rpc fails""" request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] active_request_indices = { req_idx: orig_idx for req_idx, orig_idx in enumerate(self.remaining_indices) diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index d74c1a43f..32603a17b 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -142,14 +142,12 @@ def chunk_stream( CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB] ], ) -> CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB.CellChunk]: - """ - process chunks out of raw read_rows stream + """process chunks out of raw read_rows stream Args: stream: the raw read_rows stream from the gapic client Yields: - ReadRowsResponsePB.CellChunk: the next chunk in the stream - """ + ReadRowsResponsePB.CellChunk: the next chunk in the stream""" for resp in stream: resp = resp._pb if resp.last_scanned_row_key: diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index ef6532489..833573365 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -199,9 +199,7 @@ def _start_background_channel_refresh(self) -> None: ) def close(self, timeout: float | None = None): - """ - Cancel all background tasks - """ + """Cancel all background tasks""" self._is_closed.set() for task in self._channel_refresh_tasks: task.cancel() @@ -213,8 +211,7 @@ def close(self, timeout: float | None = None): def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None ) -> list[BaseException | None]: - """ - Prepares the backend for requests on a channel + """Prepares the backend for requests on a channel Pings each Bigtable instance registered in `_active_instances` on the client @@ -257,8 +254,7 @@ def _manage_channel( refresh_interval_max: float = 60 * 45, grace_period: float = 60 * 10, ) -> None: - """ - Background coroutine that periodically refreshes and warms a grpc channel + """Background coroutine that periodically refreshes and warms a grpc channel The backend will automatically close channels after 60 minutes, so `refresh_interval` + `grace_period` should be < 60 minutes @@ -274,8 +270,7 @@ def _manage_channel( process in seconds. Actual interval will be a random value between `refresh_interval_min` and `refresh_interval_max` grace_period: time to allow previous channel to serve existing - requests before closing, in seconds - """ + requests before closing, in seconds""" first_refresh = self._channel_init_time + random.uniform( refresh_interval_min, refresh_interval_max ) @@ -579,8 +574,7 @@ def read_rows( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: - """ - Read a set of rows from the table, based on the specified query. + """Read a set of rows from the table, based on the specified query. Retruns results as a list of Row objects when the request is complete. For streamed results, use read_rows_stream. @@ -627,8 +621,7 @@ def read_row( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> Row | None: - """ - Read a single row from the table, based on the specified key. 
+ """Read a single row from the table, based on the specified key. Failed requests within operation_timeout will be retried based on the retryable_errors list until operation_timeout is reached. @@ -675,8 +668,7 @@ def read_rows_sharded( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: - """ - Runs a sharded query in parallel, then return the results in a single list. + """Runs a sharded query in parallel, then return the results in a single list. Results will be returned in the order of the input queries. This function is intended to be run on the results on a query.shard() call. @@ -703,8 +695,7 @@ def read_rows_sharded( list[Row]: a list of Rows returned by the query Raises: ShardedReadRowsExceptionGroup: if any of the queries failed - ValueError: if the query_list is empty - """ + ValueError: if the query_list is empty""" if not sharded_query: raise ValueError("empty sharded_query") operation_timeout, attempt_timeout = _helpers._get_timeouts( @@ -766,8 +757,7 @@ def row_exists( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> bool: - """ - Return a boolean indicating whether the specified row exists in the table. + """Return a boolean indicating whether the specified row exists in the table. uses the filters: chain(limit cells per row = 1, strip value) Args: @@ -812,8 +802,7 @@ def sample_row_keys( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> RowKeySamples: - """ - Return a set of RowKeySamples that delimit contiguous sections of the table of + """Return a set of RowKeySamples that delimit contiguous sections of the table of approximately equal size RowKeySamples output can be used with ReadRowsQuery.shard() to create a sharded query that @@ -929,8 +918,7 @@ def mutate_row( retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ): - """ - Mutates a row atomically. + """Mutates a row atomically. Cells already present in the row are left unchanged unless explicitly changed by ``mutation``. @@ -958,8 +946,7 @@ def mutate_row( GoogleAPIError exceptions from any retries that failed google.api_core.exceptions.GoogleAPIError: raised on non-idempotent operations that cannot be safely retried. 
- ValueError: if invalid arguments are provided - """ + ValueError: if invalid arguments are provided""" operation_timeout, attempt_timeout = _helpers._get_timeouts( operation_timeout, attempt_timeout, self ) @@ -1052,8 +1039,7 @@ def check_and_mutate_row( false_case_mutations: Mutation | list[Mutation] | None = None, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> bool: - """ - Mutates a row atomically based on the output of a predicate filter + """Mutates a row atomically based on the output of a predicate filter Non-idempotent operation: will not be retried @@ -1082,8 +1068,7 @@ def check_and_mutate_row( Returns: bool indicating whether the predicate was true or false Raises: - google.api_core.exceptions.GoogleAPIError: exceptions from grpc call - """ + google.api_core.exceptions.GoogleAPIError: exceptions from grpc call""" operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) if true_case_mutations is not None and ( not isinstance(true_case_mutations, list) @@ -1116,8 +1101,7 @@ def read_modify_write_row( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> Row: - """ - Reads and modifies a row atomically according to input ReadModifyWriteRules, + """Reads and modifies a row atomically according to input ReadModifyWriteRules, and returns the contents of all modified cells The new value for the timestamp is the greater of the existing timestamp or @@ -1137,8 +1121,7 @@ def read_modify_write_row( Row: a Row containing cell data that was modified as part of the operation Raises: google.api_core.exceptions.GoogleAPIError: exceptions from grpc call - ValueError: if invalid arguments are provided - """ + ValueError: if invalid arguments are provided""" operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) if operation_timeout <= 0: raise ValueError("operation_timeout must be greater than 0") @@ -1159,29 +1142,23 @@ def read_modify_write_row( return Row._from_pb(result.row) def close(self): - """ - Called to close the Table instance and release any resources held by it. 
- """ + """Called to close the Table instance and release any resources held by it.""" if self._register_instance_future: self._register_instance_future.cancel() self.client._remove_instance_registration(self.instance_id, self) def __enter__(self): - """ - Implement async context manager protocol + """Implement async context manager protocol Ensure registration task has time to run, so that - grpc channels will be warmed for the specified instance - """ + grpc channels will be warmed for the specified instance""" if self._register_instance_future: self._register_instance_future return self def __exit__(self, exc_type, exc_val, exc_tb): - """ - Implement async context manager protocol + """Implement async context manager protocol Unregister this instance with the client, so that - grpc channels will no longer be warmed - """ + grpc channels will no longer be warmed""" self.close() diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index da5e6450a..9052dd9c0 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -348,7 +348,7 @@ def yield_to_event_loop() -> None: files = glob.glob(search_root + "/**/*.py", recursive=True) artifacts: set[transformers.CrossSyncFileArtifact] = set() for file in files: - converter = transformers.CrossSyncClassParser(file) + converter = transformers.CrossSyncClassDecoratorHandler(file) converter.convert_file(artifacts) print(artifacts) for artifact in artifacts: diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 46274980b..006982c1f 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -97,14 +97,12 @@ def _has_capacity(self, additional_count: int, additional_size: int) -> bool: def remove_from_flow( self, mutations: RowMutationEntry | list[RowMutationEntry] ) -> None: - """ - Removes mutations from flow control. This method should be called once + """Removes mutations from flow control. This method should be called once for each mutation that was sent to add_to_flow, after the corresponding operation is complete. Args: - mutations: mutation or list of mutations to remove from flow control - """ + mutations: mutation or list of mutations to remove from flow control""" if not isinstance(mutations, list): mutations = [mutations] total_count = sum((len(entry.mutations) for entry in mutations)) @@ -115,8 +113,7 @@ def remove_from_flow( self._capacity_condition.notify_all() def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): - """ - Generator function that registers mutations with flow control. As mutations + """Generator function that registers mutations with flow control. As mutations are accepted into the flow control, they are yielded back to the caller, to be sent in a batch. If the flow control is at capacity, the generator will block until there is capacity available. @@ -126,8 +123,7 @@ def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): Yields: list[RowMutationEntry]: list of mutations that have reserved space in the flow control. - Each batch contains at least one mutation. 
- """ + Each batch contains at least one mutation.""" if not isinstance(mutations, list): mutations = [mutations] start_idx = 0 @@ -241,15 +237,13 @@ def __init__( atexit.register(self._on_exit) def _timer_routine(self, interval: float | None) -> None: - """ - Set up a background task to flush the batcher every interval seconds + """Set up a background task to flush the batcher every interval seconds If interval is None, an empty future is returned Args: flush_interval: Automatically flush every flush_interval seconds. - If None, no time-based flushing is performed. - """ + If None, no time-based flushing is performed.""" if not interval or interval <= 0: return None while not self._closed.is_set(): @@ -260,15 +254,13 @@ def _timer_routine(self, interval: float | None) -> None: self._schedule_flush() def append(self, mutation_entry: RowMutationEntry): - """ - Add a new set of mutations to the internal queue + """Add a new set of mutations to the internal queue Args: mutation_entry: new entry to add to flush queue Raises: RuntimeError: if batcher is closed - ValueError: if an invalid mutation type is added - """ + ValueError: if an invalid mutation type is added""" if self._closed.is_set(): raise RuntimeError("Cannot append to closed MutationsBatcher") if isinstance(mutation_entry, Mutation): @@ -304,12 +296,10 @@ def _schedule_flush(self) -> CrossSync._Sync_Impl.Future[None] | None: return None def _flush_internal(self, new_entries: list[RowMutationEntry]): - """ - Flushes a set of mutations to the server, and updates internal state + """Flushes a set of mutations to the server, and updates internal state Args: - new_entries list of RowMutationEntry objects to flush - """ + new_entries list of RowMutationEntry objects to flush""" in_process_requests: list[ CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] ] = [] @@ -397,11 +387,9 @@ def __enter__(self): return self def __exit__(self, exc_type, exc, tb): - """ - Allow use of context manager API. + """Allow use of context manager API. - Flushes the batcher and cleans up resources. - """ + Flushes the batcher and cleans up resources.""" self.close() @property @@ -411,9 +399,7 @@ def closed(self) -> bool: return self._closed.is_set() def close(self): - """ - Flush queue and clean up resources - """ + """Flush queue and clean up resources""" self._closed.set() self._flush_timer.cancel() self._schedule_flush() @@ -436,8 +422,7 @@ def _wait_for_batch_results( *tasks: CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] | CrossSync._Sync_Impl.Future[None], ) -> list[Exception]: - """ - Takes in a list of futures representing _execute_mutate_rows tasks, + """Takes in a list of futures representing _execute_mutate_rows tasks, waits for them to complete, and returns a list of errors encountered. 
Args: diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index 6136f4009..c22307712 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -19,6 +19,11 @@ class SymbolReplacer(ast.NodeTransformer): + """ + Replaces all instances of a symbol in an AST with a replacement + + Works for function signatures, method calls, docstrings, and type annotations + """ def __init__(self, replacements: dict[str, str]): self.replacements = replacements @@ -37,36 +42,46 @@ def visit_Attribute(self, node): node, ) - def update_docstring(self, docstring): + def visit_AsyncFunctionDef(self, node): """ - Update docstring to replace any key words in the replacements dict + Replace async function docstrings """ - if not docstring: - return docstring - for key_word, replacement in self.replacements.items(): - docstring = docstring.replace(f" {key_word} ", f" {replacement} ") - return docstring + # use same logic as FunctionDef + return self.visit_FunctionDef(node) def visit_FunctionDef(self, node): - # replace docstring - docstring = self.update_docstring(ast.get_docstring(node)) - if isinstance(node.body[0], ast.Expr) and isinstance( + """ + Replace function docstrings + """ + docstring = ast.get_docstring(node) + if docstring and isinstance(node.body[0], ast.Expr) and isinstance( node.body[0].value, ast.Str ): + for key_word, replacement in self.replacements.items(): + docstring = docstring.replace(f" {key_word} ", f" {replacement} ") node.body[0].value.s = docstring return self.generic_visit(node) def visit_Str(self, node): - """Used to replace string type annotations""" + """Replace string type annotations""" node.s = self.replacements.get(node.s, node.s) return node class AsyncToSync(ast.NodeTransformer): + """ + Replaces or strips all async keywords from a given AST + """ def visit_Await(self, node): + """ + Strips await keyword + """ return self.visit(node.value) def visit_AsyncFor(self, node): + """ + Replaces `async for` with `for` + """ return ast.copy_location( ast.For( self.visit(node.target), @@ -78,6 +93,9 @@ def visit_AsyncFor(self, node): ) def visit_AsyncWith(self, node): + """ + Replaces `async with` with `with` + """ return ast.copy_location( ast.With( [self.visit(item) for item in node.items], @@ -87,6 +105,9 @@ def visit_AsyncWith(self, node): ) def visit_AsyncFunctionDef(self, node): + """ + Replaces `async def` with `def` + """ return ast.copy_location( ast.FunctionDef( node.name, @@ -99,13 +120,18 @@ def visit_AsyncFunctionDef(self, node): ) def visit_ListComp(self, node): - # replace [x async for ...] with [x for ...] + """ + Replaces `async for` with `for` in list comprehensions + """ for generator in node.generators: generator.is_async = False return self.generic_visit(node) -class HandleCrossSyncDecorators(ast.NodeTransformer): +class CrossSyncMethodDecoratorHandler(ast.NodeTransformer): + """ + Visits each method in a class, and handles any CrossSync decorators found + """ def visit_FunctionDef(self, node): return self.visit_AsyncFunctionDef(node) @@ -213,14 +239,20 @@ def render(self, with_black=True, save_to_disk=False) -> str: return full_str -class CrossSyncClassParser(ast.NodeTransformer): +class CrossSyncClassDecoratorHandler(ast.NodeTransformer): + """ + Visits each class in the file, and if it has a CrossSync decorator, it will be transformed. 
+ + Uses CrossSyncMethodDecoratorHandler to visit and (potentially) convert each method in the class + """ def __init__(self, file_path): self.in_path = file_path self._artifact_dict: dict[str, CrossSyncFileArtifact] = {} self.imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = [] - self.cross_sync_converter = SymbolReplacer( + self.cross_sync_symbol_transformer = SymbolReplacer( {"CrossSync": "CrossSync._Sync_Impl"} ) + self.cross_sync_method_handler = CrossSyncMethodDecoratorHandler() def convert_file( self, artifacts: set[CrossSyncFileArtifact] | None = None @@ -296,10 +328,10 @@ def _transform_class( d for d in cls_ast.decorator_list if "CrossSync" not in ast.dump(d) ] # convert class contents - cls_ast = self.cross_sync_converter.visit(cls_ast) + cls_ast = self.cross_sync_symbol_transformer.visit(cls_ast) if replace_symbols: cls_ast = SymbolReplacer(replace_symbols).visit(cls_ast) - cls_ast = HandleCrossSyncDecorators().visit(cls_ast) + cls_ast = self.cross_sync_method_handler.visit(cls_ast) return cls_ast def _get_imports( @@ -311,7 +343,7 @@ def _get_imports( imports = [] for node in tree.body: if isinstance(node, (ast.Import, ast.ImportFrom, ast.Try, ast.If)): - imports.append(self.cross_sync_converter.visit(node)) + imports.append(self.cross_sync_symbol_transformer.visit(node)) return imports def _convert_ast_to_py(self, ast_node): From c93597b5036c7a44ef0a70fad001fa571778feca Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 8 Jul 2024 19:37:20 -0700 Subject: [PATCH 137/360] use custom class for decorators --- .../bigtable/data/_async/_mutate_rows.py | 2 +- .../cloud/bigtable/data/_async/_read_rows.py | 2 +- google/cloud/bigtable/data/_async/client.py | 4 +- .../bigtable/data/_async/mutations_batcher.py | 4 +- .../cloud/bigtable/data/_sync/cross_sync.py | 129 ++++++++++++----- .../cloud/bigtable/data/_sync/transformers.py | 131 +++++++----------- tests/system/data/test_system_async.py | 4 +- tests/unit/data/_async/test__mutate_rows.py | 2 +- tests/unit/data/_async/test__read_rows.py | 2 +- tests/unit/data/_async/test_client.py | 18 +-- .../data/_async/test_mutations_batcher.py | 4 +- .../data/_async/test_read_rows_acceptance.py | 2 +- 12 files changed, 163 insertions(+), 141 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index fee05fc39..bd5feefe7 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -63,7 +63,7 @@ class _EntryWithProto: # noqa: F811 @CrossSync.sync_output( - "google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", + path="google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", ) class _MutateRowsOperationAsync: """ diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index b375e9d03..7f98889d9 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -52,7 +52,7 @@ def __init__(self, chunk): @CrossSync.sync_output( - "google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", + path="google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", ) class _ReadRowsOperationAsync: """ diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 237d66fbf..aebeb5315 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -111,7 +111,7 @@ 
@CrossSync.sync_output( - "google.cloud.bigtable.data._sync.client.BigtableDataClient", + path="google.cloud.bigtable.data._sync.client.BigtableDataClient", ) class BigtableDataClientAsync(ClientWithProject): @CrossSync.convert( @@ -493,7 +493,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) -@CrossSync.sync_output("google.cloud.bigtable.data._sync.client.Table") +@CrossSync.sync_output(path="google.cloud.bigtable.data._sync.client.Table") class TableAsync: """ Main Data API surface diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 4a83683dd..f7404d0b8 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -52,7 +52,7 @@ @CrossSync.sync_output( - "google.cloud.bigtable.data._sync.mutations_batcher._FlowControl" + path="google.cloud.bigtable.data._sync.mutations_batcher._FlowControl" ) class _FlowControlAsync: """ @@ -183,7 +183,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] @CrossSync.sync_output( - "google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", + path="google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", mypy_ignore=["unreachable"], ) class MutationsBatcherAsync: diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 9052dd9c0..e439cbb6b 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -41,9 +41,88 @@ T = TypeVar("T") +class AstDecorator: + """ + Helper class for CrossSync decorators used for guiding ast transformations. + + These decorators provide arguments that are used during the code generation process, + but act as no-ops when encountered in live code + """ + + def __init__(self, name, required_keywords=(), **default_kwargs): + self.name = name + self.required_kwargs = required_keywords + self.default_kwargs = default_kwargs + self.all_valid_keys = [*required_keywords, *default_kwargs.keys()] + + def __call__(self, *args, **kwargs): + for kwarg in kwargs: + if kwarg not in self.all_valid_keys: + raise ValueError(f"Invalid keyword argument: {kwarg}") + if len(args) == 1 and callable(args[0]): + return args[0] + def decorator(func): + return func + return decorator -class CrossSync: - SyncImports = False + def parse_ast_keywords(self, node): + got_kwargs = { + kw.arg: self._convert_ast_to_py(kw.value) + for kw in node.keywords + } if hasattr(node, "keywords") else {} + for key in got_kwargs.keys(): + if key not in self.all_valid_keys: + raise ValueError(f"Invalid keyword argument: {key}") + for key in self.required_kwargs: + if key not in got_kwargs: + raise ValueError(f"Missing required keyword argument: {key}") + return {**self.default_kwargs, **got_kwargs} + + def _convert_ast_to_py(self, ast_node): + """ + Helper to convert ast primitives to python primitives. 
Used when unwrapping kwargs + """ + import ast + if isinstance(ast_node, ast.Constant): + return ast_node.value + if isinstance(ast_node, ast.List): + return [self._convert_ast_to_py(node) for node in ast_node.elts] + if isinstance(ast_node, ast.Dict): + return { + self._convert_ast_to_py(k): self._convert_ast_to_py(v) + for k, v in zip(ast_node.keys, ast_node.values) + } + raise ValueError(f"Unsupported type {type(ast_node)}") + + def _node_eq(self, node: ast.Node): + import ast + if "CrossSync" in ast.dump(node): + decorator_type = ( + node.func.attr + if hasattr(node, "func") + else node.attr + ) + if decorator_type == self.name: + return True + return False + + def __eq__(self, other): + return self._node_eq(other) + + +class _DecoratorMeta(type): + """ + Metaclass to attach AstDecorator objects in internal self._decorators + as attributes + """ + + def __getattr__(self, name): + for decorator in self._decorators: + if name == decorator.name: + return decorator + return super().__getattr__(name) + +class CrossSync(metaclass=_DecoratorMeta): is_async = True sleep = asyncio.sleep @@ -63,24 +142,19 @@ class CrossSync: Iterator: TypeAlias = AsyncIterator Generator: TypeAlias = AsyncGenerator - generated_replacements: dict[type, str] = {} - - @staticmethod - def convert( - *args, sync_name: str | None = None, replace_symbols: dict[str, str] | None = None - ): - if args: - # only positional argument should be the function to wrap. Return it directly - return args[0] - - def decorator(func): - return func - - return decorator - - @staticmethod - def drop_method(func): - return func + _decorators: list[AstDecorator] = [ + AstDecorator("sync_output", # decorate classes to convert + required_keywords=["path"], # otput path for generated sync class + replace_symbols={}, # replace specific symbols across entire class + mypy_ignore=(), # set of mypy error codes to ignore in output file + include_file_imports=True # when True, import statements from top of file will be included in output file + ), + AstDecorator("convert", # decorate methods to convert from async to sync + sync_name=None, # use a new name for the sync class + replace_symbols={} # replace specific symbols within the function + ), + AstDecorator("drop_method"), # decorate methods to drop in sync version of class + ] @classmethod def Mock(cls, *args, **kwargs): @@ -90,21 +164,6 @@ def Mock(cls, *args, **kwargs): from mock import AsyncMock # type: ignore return AsyncMock(*args, **kwargs) - @classmethod - def sync_output( - cls, - sync_path: str, - *, - replace_symbols: dict["str", "str" | None] | None = None, - mypy_ignore: list[str] | None = None, - include_file_imports: bool = False, - ): - # return the async class unchanged - def decorator(async_cls): - return async_cls - - return decorator - @staticmethod def pytest(func): import pytest diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index c22307712..58fd6a36b 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -16,6 +16,7 @@ import ast from dataclasses import dataclass, field +from .cross_sync import CrossSync class SymbolReplacer(ast.NodeTransformer): @@ -137,46 +138,25 @@ def visit_FunctionDef(self, node): return self.visit_AsyncFunctionDef(node) def visit_AsyncFunctionDef(self, node): - if hasattr(node, "decorator_list"): - found_list, node.decorator_list = node.decorator_list, [] - for decorator in found_list: - if "CrossSync" in 
ast.dump(decorator): - decorator_type = ( - decorator.func.attr - if hasattr(decorator, "func") - else decorator.attr - ) - if decorator_type == "convert": + try: + if hasattr(node, "decorator_list"): + found_list, node.decorator_list = node.decorator_list, [] + for decorator in found_list: + if decorator == CrossSync.convert: + kwargs = CrossSync.convert.parse_ast_keywords(decorator) # convert async to sync node = AsyncToSync().visit(node) - for subcommand in getattr(decorator, "keywords", []): - if subcommand.arg == "sync_name": - node.name = subcommand.value.s - if subcommand.arg == "replace_symbols": - replacements = { - subcommand.value.keys[i] - .s: subcommand.value.values[i] - .s - for i in range(len(subcommand.value.keys)) - } - node = SymbolReplacer(replacements).visit(node) - elif decorator_type == "pytest": - pass - elif decorator_type == "pytest_fixture": - # keep decorator - node.decorator_list.append(decorator) - elif decorator_type == "Retry": - node.decorator_list.append(decorator) - elif decorator_type == "drop_method": + if kwargs["sync_name"] is not None: + node.name = kwargs["sync_name"] + if kwargs["replace_symbols"]: + node = SymbolReplacer(kwargs["replace_symbols"]).visit(node) + elif decorator == CrossSync.drop_method: return None else: - raise ValueError( - f"Unsupported CrossSync decorator: {decorator_type}" - ) - else: - # add non-crosssync decorators back - node.decorator_list.append(decorator) - return node + node.decorator_list.append(decorator) + return node + except ValueError as e: + raise ValueError(f"node {node.name} failed") from e @dataclass @@ -275,40 +255,36 @@ def visit_ClassDef(self, node): Called for each class in file. If class has a CrossSync decorator, it will be transformed according to the decorator arguments """ - for decorator in node.decorator_list: - if "CrossSync" in ast.dump(decorator): - kwargs = { - kw.arg: self._convert_ast_to_py(kw.value) - for kw in decorator.keywords - } - # find the path to write the sync class to - sync_path = kwargs.pop("sync_path", None) - if not sync_path: - sync_path = decorator.args[0].s - out_file = "/".join(sync_path.rsplit(".")[:-1]) + ".py" - sync_cls_name = sync_path.rsplit(".", 1)[-1] - # find the artifact file for the save location - output_artifact = self._artifact_dict.get( - out_file, CrossSyncFileArtifact(out_file) - ) - # write converted class details if not already present - if sync_cls_name not in output_artifact.contained_classes: - converted = self._transform_class(node, sync_cls_name, **kwargs) - output_artifact.converted_classes.append(converted) - # handle file-level mypy ignores - mypy_ignores = [ - s - for s in kwargs.get("mypy_ignore", []) - if s not in output_artifact.mypy_ignore - ] - output_artifact.mypy_ignore.extend(mypy_ignores) - # handle file-level imports - if not output_artifact.imports and kwargs.get( - "include_file_imports", True - ): - output_artifact.imports = self.imports - self._artifact_dict[out_file] = output_artifact - return node + try: + for decorator in node.decorator_list: + if decorator == CrossSync.sync_output: + kwargs = CrossSync.sync_output.parse_ast_keywords(decorator) + # find the path to write the sync class to + sync_path = kwargs["path"] + out_file = "/".join(sync_path.rsplit(".")[:-1]) + ".py" + sync_cls_name = sync_path.rsplit(".", 1)[-1] + # find the artifact file for the save location + output_artifact = self._artifact_dict.get( + out_file, CrossSyncFileArtifact(out_file) + ) + # write converted class details if not already present + if 
sync_cls_name not in output_artifact.contained_classes: + converted = self._transform_class(node, sync_cls_name, **kwargs) + output_artifact.converted_classes.append(converted) + # handle file-level mypy ignores + mypy_ignores = [ + s + for s in kwargs["mypy_ignore"] + if s not in output_artifact.mypy_ignore + ] + output_artifact.mypy_ignore.extend(mypy_ignores) + # handle file-level imports + if not output_artifact.imports and kwargs["include_file_imports"]: + output_artifact.imports = self.imports + self._artifact_dict[out_file] = output_artifact + return node + except ValueError as e: + raise ValueError(f"failed for class: {node.name}") from e def _transform_class( self, @@ -346,17 +322,4 @@ def _get_imports( imports.append(self.cross_sync_symbol_transformer.visit(node)) return imports - def _convert_ast_to_py(self, ast_node): - """ - Helper to convert ast primitives to python primitives. Used when unwrapping kwargs - """ - if isinstance(ast_node, ast.Constant): - return ast_node.value - if isinstance(ast_node, ast.List): - return [self._convert_ast_to_py(node) for node in ast_node.elts] - if isinstance(ast_node, ast.Dict): - return { - self._convert_ast_to_py(k): self._convert_ast_to_py(v) - for k, v in zip(ast_node.keys, ast_node.values) - } - raise ValueError(f"Unsupported type {type(ast_node)}") + diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 2f259d6eb..6b9b14e3d 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -34,7 +34,7 @@ TEST_FAMILY = "test-family" TEST_FAMILY_2 = "test-family-2" -@CrossSync.sync_output("tests.system.data.test_system.TempRowBuilder") +@CrossSync.sync_output(path="tests.system.data.test_system.TempRowBuilder") class TempRowBuilderAsync: """ Used to add rows to a table for testing purposes. 
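As a rough illustration of what the parse_ast_keywords call in visit_ClassDef above is doing, keyword arguments such as path= are read straight out of the decorator's AST node instead of being evaluated at runtime. The sketch below uses only the standard-library ast module, with a made-up decorator name and path, and mirrors only the constant/list handling (Python 3.8+):

import ast

example = '''
@sync_output(path="pkg._sync.client.Table", mypy_ignore=["unreachable"])
class TableAsync:
    pass
'''

def to_python(node):
    # only the literal shapes the generator needs: constants and lists
    if isinstance(node, ast.Constant):
        return node.value
    if isinstance(node, ast.List):
        return [to_python(elt) for elt in node.elts]
    raise ValueError(f"unsupported node: {type(node).__name__}")

cls_node = ast.parse(example).body[0]
decorator = cls_node.decorator_list[0]  # the sync_output(...) Call node
kwargs = {kw.arg: to_python(kw.value) for kw in decorator.keywords}
print(kwargs)  # {'path': 'pkg._sync.client.Table', 'mypy_ignore': ['unreachable']}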
@@ -79,7 +79,7 @@ async def delete_rows(self): await self.table.client._gapic_client.mutate_rows(request) -@CrossSync.sync_output("tests.system.data.test_system.TestSystem") +@CrossSync.sync_output(path="tests.system.data.test_system.TestSystem") class TestSystemAsync: @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}) diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 8743182a2..a7f3fa195 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -29,7 +29,7 @@ @CrossSync.sync_output( - "tests.unit.data._sync.test__mutate_rows.TestMutateRowsOperation", + path="tests.unit.data._sync.test__mutate_rows.TestMutateRowsOperation", ) class TestMutateRowsOperation: def _target_class(self): diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 405e94e57..11fbc1bc1 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -35,7 +35,7 @@ @CrossSync.sync_output( - "tests.unit.data._sync.test__read_rows.TestReadRowsOperation", + path="tests.unit.data._sync.test__read_rows.TestReadRowsOperation", ) class TestReadRowsOperationAsync: """ diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index a2b468cb8..0cfdb2e70 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -77,7 +77,7 @@ @CrossSync.sync_output( - "tests.unit.data._sync.test_client.TestBigtableDataClient", + path="tests.unit.data._sync.test_client.TestBigtableDataClient", replace_symbols={ "TestTableAsync": "TestTable", "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", @@ -1142,7 +1142,7 @@ def test_client_ctor_sync(self): assert client._channel_refresh_tasks == [] -@CrossSync.sync_output("tests.unit.data._sync.test_client.TestTable") +@CrossSync.sync_output(path="tests.unit.data._sync.test_client.TestTable") class TestTableAsync: @CrossSync.convert( replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} @@ -1456,7 +1456,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ assert "app_profile_id=" not in goog_metadata -@CrossSync.sync_output("tests.unit.data._sync.test_client.TestReadRows") +@CrossSync.sync_output(path="tests.unit.data._sync.test_client.TestReadRows") class TestReadRowsAsync: """ Tests for table.read_rows and related methods. 
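The replace_symbols mappings passed to these decorators are applied by the generator as plain identifier rewrites on the parsed tree. Roughly, and with illustrative names only (the real SymbolReplacer also handles attributes, docstrings, and string annotations; ast.unparse needs Python 3.9+):

import ast

class RenameSymbols(ast.NodeTransformer):
    def __init__(self, replacements):
        self.replacements = replacements

    def visit_Name(self, node):
        # swap bare identifiers, e.g. TestTableAsync -> TestTable
        node.id = self.replacements.get(node.id, node.id)
        return node

tree = ast.parse("table = TestTableAsync._make_client(project)")
RenameSymbols({"TestTableAsync": "TestTable"}).visit(tree)
print(ast.unparse(tree))  # table = TestTable._make_client(project)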
@@ -1972,7 +1972,7 @@ async def test_row_exists(self, return_value, expected_result): assert query.filter._to_dict() == expected_filter -@CrossSync.sync_output("tests.unit.data._sync.test_client.TestReadRowsSharded") +@CrossSync.sync_output(path="tests.unit.data._sync.test_client.TestReadRowsSharded") class TestReadRowsShardedAsync: @CrossSync.convert( replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} @@ -2197,7 +2197,7 @@ async def mock_call(*args, **kwargs): ) -@CrossSync.sync_output("tests.unit.data._sync.test_client.TestSampleRowKeys") +@CrossSync.sync_output(path="tests.unit.data._sync.test_client.TestSampleRowKeys") class TestSampleRowKeysAsync: @CrossSync.convert( replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} @@ -2353,7 +2353,7 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio @CrossSync.sync_output( - "tests.unit.data._sync.test_client.TestMutateRow", + path="tests.unit.data._sync.test_client.TestMutateRow", ) class TestMutateRowAsync: @CrossSync.convert( @@ -2534,7 +2534,7 @@ async def test_mutate_row_no_mutations(self, mutations): @CrossSync.sync_output( - "tests.unit.data._sync.test_client.TestBulkMutateRows", + path="tests.unit.data._sync.test_client.TestBulkMutateRows", ) class TestBulkMutateRowsAsync: @CrossSync.convert( @@ -2918,7 +2918,7 @@ async def test_bulk_mutate_error_recovery(self): await table.bulk_mutate_rows(entries, operation_timeout=1000) -@CrossSync.sync_output("tests.unit.data._sync.test_client.TestCheckAndMutateRow") +@CrossSync.sync_output(path="tests.unit.data._sync.test_client.TestCheckAndMutateRow") class TestCheckAndMutateRowAsync: @CrossSync.convert( replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} @@ -3073,7 +3073,7 @@ async def test_check_and_mutate_mutations_parsing(self): ) -@CrossSync.sync_output("tests.unit.data._sync.test_client.TestReadModifyWriteRow") +@CrossSync.sync_output(path="tests.unit.data._sync.test_client.TestReadModifyWriteRow") class TestReadModifyWriteRowAsync: @CrossSync.convert( replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 765ac4e13..1d8a75a6d 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -43,7 +43,7 @@ import mock # type: ignore -@CrossSync.sync_output("tests.unit.data._sync.test_mutations_batcher.Test_FlowControl") +@CrossSync.sync_output(path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl") class Test_FlowControl: @staticmethod @CrossSync.convert(replace_symbols={"_FlowControlAsync": "_FlowControl"}) @@ -320,7 +320,7 @@ async def test_add_to_flow_oversize(self): @CrossSync.sync_output( - "tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" + path="tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" ) class TestMutationsBatcherAsync: @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 4a9939abf..0f4996d9a 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -65,7 +65,7 @@ class TestFile(proto.Message): # noqa: F811 @CrossSync.sync_output( - "tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", + 
path="tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", ) class TestReadRowsAcceptanceAsync: @staticmethod From 32f16319e6ed0ecfd50d85d9a6b2ce669831cc37 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 13:32:59 -0700 Subject: [PATCH 138/360] mark pytest methods for conversion --- tests/system/data/test_system_async.py | 5 +++++ tests/unit/data/_async/test__mutate_rows.py | 1 + tests/unit/data/_async/test_client.py | 4 ++++ tests/unit/data/_async/test_mutations_batcher.py | 1 + tests/unit/data/_async/test_read_rows_acceptance.py | 2 ++ 5 files changed, 13 insertions(+) diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 6b9b14e3d..3e67c6fe3 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -44,6 +44,7 @@ def __init__(self, table): self.rows = [] self.table = table + @CrossSync.convert async def add_row( self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" ): @@ -67,6 +68,7 @@ async def add_row( await self.table.client._gapic_client.mutate_row(request) self.rows.append(row_key) + @CrossSync.convert async def delete_rows(self): if self.rows: request = { @@ -89,6 +91,7 @@ async def client(self): async with BigtableDataClientAsync(project=project, pool_size=4) as client: yield client + @CrossSync.convert @CrossSync.pytest_fixture(scope="session") async def table(self, client, table_id, instance_id): async with client.get_table( @@ -136,6 +139,7 @@ def cluster_config(self, project_id): } return cluster + @CrossSync.convert @pytest.mark.usefixtures("table") async def _retrieve_cell_value(self, table, row_key): """ @@ -149,6 +153,7 @@ async def _retrieve_cell_value(self, table, row_key): cell = row.cells[0] return cell.value + @CrossSync.convert async def _create_row_and_mutation( self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" ): diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index a7f3fa195..8be885a92 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -62,6 +62,7 @@ def _make_mutation(self, count=1, size=1): mutation.mutations = [mock.Mock()] * count return mutation + @CrossSync.convert async def _mock_stream(self, mutation_list, error_dict): for idx, entry in enumerate(mutation_list): code = error_dict.get(idx, 0) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 0cfdb2e70..ecfd9af2b 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1523,6 +1523,7 @@ def _make_chunk(*args, **kwargs): return ReadRowsResponse.CellChunk(*args, **kwargs) @staticmethod + @CrossSync.convert async def _make_gapic_stream( chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0, @@ -1561,6 +1562,7 @@ def cancel(self): return mock_stream(chunk_list, sleep_time) + @CrossSync.convert async def execute_fn(self, table, *args, **kwargs): return await table.read_rows(*args, **kwargs) @@ -2205,6 +2207,7 @@ class TestSampleRowKeysAsync: def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @CrossSync.convert async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): from google.cloud.bigtable_v2.types import SampleRowKeysResponse @@ -2543,6 +2546,7 @@ class TestBulkMutateRowsAsync: def _make_client(self, *args, **kwargs): return 
TestBigtableDataClientAsync._make_client(*args, **kwargs) + @CrossSync.convert async def _mock_response(self, response_list): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 1d8a75a6d..3d31eeacf 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -906,6 +906,7 @@ async def gen(x): instance._oldest_exceptions.clear() instance._newest_exceptions.clear() + @CrossSync.convert async def _mock_gapic_return(self, num=5): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 0f4996d9a..5b87d81b0 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -108,9 +108,11 @@ def extract_results_from_row(row: Row): return results @staticmethod + @CrossSync.convert async def _coro_wrapper(stream): return stream + @CrossSync.convert async def _process_chunks(self, *chunks): async def _row_stream(): yield ReadRowsResponse(chunks=chunks) From b6fb1d55f64e8506788dca945062f36bf962ff2b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 13:33:54 -0700 Subject: [PATCH 139/360] convert crosssync.pytest to sync --- .../cloud/bigtable/data/_sync/cross_sync.py | 19 ++++++++++++------- .../cloud/bigtable/data/_sync/transformers.py | 3 +++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index e439cbb6b..28b7c653a 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -41,6 +41,13 @@ T = TypeVar("T") +def pytest_mark_asyncio(func): + try: + import pytest + return pytest.mark.asyncio(func) + except ImportError: + return func + class AstDecorator: """ Helper class for CrossSync decorators used for guiding ast transformations. 
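The pattern that docstring describes, arguments that only matter to the code generator while the decorator itself is a runtime no-op (optionally deferring to a real decorator such as pytest.mark.asyncio), reduces to a small sketch. Names here are illustrative, not the library's API:

def convert(*args, sync_name=None, replace_symbols=None):
    """Arguments only guide code generation; at runtime the wrapped
    function is returned unchanged."""
    if len(args) == 1 and callable(args[0]):   # bare usage: @convert
        return args[0]

    def decorator(func):                       # parameterized: @convert(sync_name=...)
        return func

    return decorator

@convert
async def ping(): ...

@convert(sync_name="read_rows", replace_symbols={"AsyncIterator": "Iterator"})
async def read_rows_async(): ...

assert ping.__name__ == "ping"
assert read_rows_async.__name__ == "read_rows_async"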
@@ -49,16 +56,19 @@ class AstDecorator: but act as no-ops when encountered in live code """ - def __init__(self, name, required_keywords=(), **default_kwargs): + def __init__(self, name, required_keywords=(), inner_decorator=None, **default_kwargs): self.name = name self.required_kwargs = required_keywords self.default_kwargs = default_kwargs self.all_valid_keys = [*required_keywords, *default_kwargs.keys()] + self.inner_decorator = inner_decorator def __call__(self, *args, **kwargs): for kwarg in kwargs: if kwarg not in self.all_valid_keys: raise ValueError(f"Invalid keyword argument: {kwarg}") + if self.inner_decorator: + return self.inner_decorator(*args, **kwargs) if len(args) == 1 and callable(args[0]): return args[0] def decorator(func): @@ -154,6 +164,7 @@ class CrossSync(metaclass=_DecoratorMeta): replace_symbols={} # replace specific symbols within the function ), AstDecorator("drop_method"), # decorate methods to drop in sync version of class + AstDecorator("pytest", inner_decorator=pytest_mark_asyncio) ] @classmethod @@ -164,12 +175,6 @@ def Mock(cls, *args, **kwargs): from mock import AsyncMock # type: ignore return AsyncMock(*args, **kwargs) - @staticmethod - def pytest(func): - import pytest - - return pytest.mark.asyncio(func) - @staticmethod def pytest_fixture(*args, **kwargs): import pytest_asyncio diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index 58fd6a36b..f8efb34c0 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -152,6 +152,9 @@ def visit_AsyncFunctionDef(self, node): node = SymbolReplacer(kwargs["replace_symbols"]).visit(node) elif decorator == CrossSync.drop_method: return None + elif decorator == CrossSync.pytest: + # also convert pytest methods to sync + node = AsyncToSync().visit(node) else: node.decorator_list.append(decorator) return node From f90d54cbe0b4b1575262521a209fb6572870b33e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 13:47:13 -0700 Subject: [PATCH 140/360] removed unneeded check --- tests/unit/data/_async/test_client.py | 1 - tests/unit/data/_sync/test_client.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index ecfd9af2b..1b95956d7 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1090,7 +1090,6 @@ async def test_close(self): close_mock.assert_awaited() for task in tasks_list: assert task.done() - assert client._channel_refresh_tasks == [] @CrossSync.pytest async def test_close_with_timeout(self): diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index e25fea3ac..a6415a0d2 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -918,7 +918,6 @@ def test_close(self): close_mock.assert_awaited() for task in tasks_list: assert task.done() - assert client._channel_refresh_tasks == [] def test_close_with_timeout(self): pool_size = 7 From f5dfa3ef0cc113cff43080e1f5fc62da3fa70942 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 14:08:51 -0700 Subject: [PATCH 141/360] use AstDecorator for pytest_fixture --- .../cloud/bigtable/data/_sync/cross_sync.py | 36 +++++++++---------- .../cloud/bigtable/data/_sync/transformers.py | 11 +++++- tests/system/data/test_system.py | 6 ++-- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git 
a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 28b7c653a..f8d0ff232 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -48,6 +48,12 @@ def pytest_mark_asyncio(func): except ImportError: return func +def pytest_asyncio_fixture(*args, **kwargs): + import pytest_asyncio + def decorator(func): + return pytest_asyncio.fixture(*args, **kwargs)(func) + return decorator + class AstDecorator: """ Helper class for CrossSync decorators used for guiding ast transformations. @@ -56,8 +62,8 @@ class AstDecorator: but act as no-ops when encountered in live code """ - def __init__(self, name, required_keywords=(), inner_decorator=None, **default_kwargs): - self.name = name + def __init__(self, decorator_name, required_keywords=(), inner_decorator=None, **default_kwargs): + self.name = decorator_name self.required_kwargs = required_keywords self.default_kwargs = default_kwargs self.all_valid_keys = [*required_keywords, *default_kwargs.keys()] @@ -161,10 +167,18 @@ class CrossSync(metaclass=_DecoratorMeta): ), AstDecorator("convert", # decorate methods to convert from async to sync sync_name=None, # use a new name for the sync class - replace_symbols={} # replace specific symbols within the function + replace_symbols={}, # replace specific symbols within the function ), AstDecorator("drop_method"), # decorate methods to drop in sync version of class - AstDecorator("pytest", inner_decorator=pytest_mark_asyncio) + AstDecorator("pytest", inner_decorator=pytest_mark_asyncio), # decorate test methods to run with pytest-asyncio + AstDecorator("pytest_fixture", # decorate test methods to run with pytest fixture + inner_decorator=pytest_asyncio_fixture, + scope="function", + params=None, + autouse=False, + ids=None, + name=None, + ), ] @classmethod @@ -175,13 +189,6 @@ def Mock(cls, *args, **kwargs): from mock import AsyncMock # type: ignore return AsyncMock(*args, **kwargs) - @staticmethod - def pytest_fixture(*args, **kwargs): - import pytest_asyncio - def decorator(func): - return pytest_asyncio.fixture(*args, **kwargs)(func) - return decorator - @staticmethod async def gather_partials( partial_list: Sequence[Callable[[], Awaitable[T]]], @@ -350,13 +357,6 @@ def event_wait( ) -> None: event.wait(timeout=timeout) - @staticmethod - def pytest_fixture(*args, **kwargs): - import pytest - def decorator(func): - return pytest.fixture(*args, **kwargs)(func) - return decorator - @staticmethod def gather_partials( partial_list: Sequence[Callable[[], T]], diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index f8efb34c0..a7a17329d 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -143,19 +143,28 @@ def visit_AsyncFunctionDef(self, node): found_list, node.decorator_list = node.decorator_list, [] for decorator in found_list: if decorator == CrossSync.convert: - kwargs = CrossSync.convert.parse_ast_keywords(decorator) # convert async to sync + kwargs = CrossSync.convert.parse_ast_keywords(decorator) node = AsyncToSync().visit(node) + # replace method name if specified if kwargs["sync_name"] is not None: node.name = kwargs["sync_name"] + # replace symbols if specified if kwargs["replace_symbols"]: node = SymbolReplacer(kwargs["replace_symbols"]).visit(node) elif decorator == CrossSync.drop_method: + # drop method entirely from class return None elif decorator == 
CrossSync.pytest: # also convert pytest methods to sync node = AsyncToSync().visit(node) + elif decorator == CrossSync.pytest_fixture: + # add pytest.fixture decorator + decorator.func.value = ast.Name(id="pytest", ctx=ast.Load()) + decorator.func.attr = "fixture" + node.decorator_list.append(decorator) else: + # keep unknown decorators node.decorator_list.append(decorator) return node except ValueError as e: diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py index 46605cf4e..e7330bf57 100644 --- a/tests/system/data/test_system.py +++ b/tests/system/data/test_system.py @@ -74,13 +74,13 @@ def delete_rows(self): class TestSystem: - @CrossSync._Sync_Impl.pytest_fixture(scope="session") + @pytest.fixture(scope="session") def client(self): project = os.getenv("GOOGLE_CLOUD_PROJECT") or None with BigtableDataClient(project=project, pool_size=4) as client: yield client - @CrossSync._Sync_Impl.pytest_fixture(scope="session") + @pytest.fixture(scope="session") def table(self, client, table_id, instance_id): with client.get_table(instance_id, table_id) as table: yield table @@ -136,7 +136,7 @@ def _create_row_and_mutation( mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) return (row_key, mutation) - @CrossSync._Sync_Impl.pytest_fixture(scope="function") + @pytest.fixture(scope="function") def temp_rows(self, table): builder = TempRowBuilder(table) yield builder From ce45742428c9d3a1318635dd03e8adcf84e43b41 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 14:15:14 -0700 Subject: [PATCH 142/360] renamed class decorator --- .../cloud/bigtable/data/_async/_mutate_rows.py | 2 +- .../cloud/bigtable/data/_async/_read_rows.py | 2 +- google/cloud/bigtable/data/_async/client.py | 4 ++-- .../bigtable/data/_async/mutations_batcher.py | 4 ++-- google/cloud/bigtable/data/_sync/cross_sync.py | 4 ++-- .../cloud/bigtable/data/_sync/transformers.py | 4 ++-- tests/system/data/test_system_async.py | 4 ++-- tests/unit/data/_async/test__mutate_rows.py | 2 +- tests/unit/data/_async/test__read_rows.py | 2 +- tests/unit/data/_async/test_client.py | 18 +++++++++--------- .../unit/data/_async/test_mutations_batcher.py | 4 ++-- .../data/_async/test_read_rows_acceptance.py | 2 +- 12 files changed, 26 insertions(+), 26 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index bd5feefe7..6d4d2f2e8 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -62,7 +62,7 @@ class _EntryWithProto: # noqa: F811 proto: types_pb.MutateRowsRequest.Entry -@CrossSync.sync_output( +@CrossSync.export_sync( path="google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", ) class _MutateRowsOperationAsync: diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 7f98889d9..8c982427c 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -51,7 +51,7 @@ def __init__(self, chunk): self.chunk = chunk -@CrossSync.sync_output( +@CrossSync.export_sync( path="google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", ) class _ReadRowsOperationAsync: diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index aebeb5315..e7d84ebf1 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -110,7 +110,7 
@@ from google.cloud.bigtable.data._helpers import ShardedQuery -@CrossSync.sync_output( +@CrossSync.export_sync( path="google.cloud.bigtable.data._sync.client.BigtableDataClient", ) class BigtableDataClientAsync(ClientWithProject): @@ -493,7 +493,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) -@CrossSync.sync_output(path="google.cloud.bigtable.data._sync.client.Table") +@CrossSync.export_sync(path="google.cloud.bigtable.data._sync.client.Table") class TableAsync: """ Main Data API surface diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index f7404d0b8..b9a6a3339 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -51,7 +51,7 @@ from google.cloud.bigtable.data._sync.client import Table # noqa: F401 -@CrossSync.sync_output( +@CrossSync.export_sync( path="google.cloud.bigtable.data._sync.mutations_batcher._FlowControl" ) class _FlowControlAsync: @@ -182,7 +182,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] yield mutations[start_idx:end_idx] -@CrossSync.sync_output( +@CrossSync.export_sync( path="google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", mypy_ignore=["unreachable"], ) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index f8d0ff232..dd87a63b5 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -136,7 +136,7 @@ def __getattr__(self, name): for decorator in self._decorators: if name == decorator.name: return decorator - return super().__getattr__(name) + raise AttributeError(f"CrossSync has no attribute {name}") class CrossSync(metaclass=_DecoratorMeta): is_async = True @@ -159,7 +159,7 @@ class CrossSync(metaclass=_DecoratorMeta): Generator: TypeAlias = AsyncGenerator _decorators: list[AstDecorator] = [ - AstDecorator("sync_output", # decorate classes to convert + AstDecorator("export_sync", # decorate classes to convert required_keywords=["path"], # otput path for generated sync class replace_symbols={}, # replace specific symbols across entire class mypy_ignore=(), # set of mypy error codes to ignore in output file diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py index a7a17329d..60498763d 100644 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ b/google/cloud/bigtable/data/_sync/transformers.py @@ -269,8 +269,8 @@ def visit_ClassDef(self, node): """ try: for decorator in node.decorator_list: - if decorator == CrossSync.sync_output: - kwargs = CrossSync.sync_output.parse_ast_keywords(decorator) + if decorator == CrossSync.export_sync: + kwargs = CrossSync.export_sync.parse_ast_keywords(decorator) # find the path to write the sync class to sync_path = kwargs["path"] out_file = "/".join(sync_path.rsplit(".")[:-1]) + ".py" diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 3e67c6fe3..32ff5f49c 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -34,7 +34,7 @@ TEST_FAMILY = "test-family" TEST_FAMILY_2 = "test-family-2" -@CrossSync.sync_output(path="tests.system.data.test_system.TempRowBuilder") +@CrossSync.export_sync(path="tests.system.data.test_system.TempRowBuilder") class TempRowBuilderAsync: """ Used to add 
rows to a table for testing purposes. @@ -81,7 +81,7 @@ async def delete_rows(self): await self.table.client._gapic_client.mutate_rows(request) -@CrossSync.sync_output(path="tests.system.data.test_system.TestSystem") +@CrossSync.export_sync(path="tests.system.data.test_system.TestSystem") class TestSystemAsync: @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}) diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 8be885a92..292cbd692 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -28,7 +28,7 @@ import mock # type: ignore -@CrossSync.sync_output( +@CrossSync.export_sync( path="tests.unit.data._sync.test__mutate_rows.TestMutateRowsOperation", ) class TestMutateRowsOperation: diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 11fbc1bc1..076e86788 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -34,7 +34,7 @@ TEST_LABELS = ["label1", "label2"] -@CrossSync.sync_output( +@CrossSync.export_sync( path="tests.unit.data._sync.test__read_rows.TestReadRowsOperation", ) class TestReadRowsOperationAsync: diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 1b95956d7..0f5775fac 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -76,7 +76,7 @@ ) -@CrossSync.sync_output( +@CrossSync.export_sync( path="tests.unit.data._sync.test_client.TestBigtableDataClient", replace_symbols={ "TestTableAsync": "TestTable", @@ -1141,7 +1141,7 @@ def test_client_ctor_sync(self): assert client._channel_refresh_tasks == [] -@CrossSync.sync_output(path="tests.unit.data._sync.test_client.TestTable") +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestTable") class TestTableAsync: @CrossSync.convert( replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} @@ -1455,7 +1455,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ assert "app_profile_id=" not in goog_metadata -@CrossSync.sync_output(path="tests.unit.data._sync.test_client.TestReadRows") +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRows") class TestReadRowsAsync: """ Tests for table.read_rows and related methods. 
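The sync variants generated from these decorated async test classes come out of AST rewrites rather than hand-written code. A simplified sketch of the core await-stripping step (the real transformer also rewrites async def, async for, and async with; ast.unparse needs Python 3.9+):

import ast

class StripAwait(ast.NodeTransformer):
    def visit_Await(self, node):
        # `await expr` collapses to plain `expr`
        return self.visit(node.value)

src = "async def count_rows(table):\n    return len(await table.read_rows())"
tree = StripAwait().visit(ast.parse(src))
print(ast.unparse(tree))
# async def count_rows(table):
#     return len(table.read_rows())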
@@ -1973,7 +1973,7 @@ async def test_row_exists(self, return_value, expected_result): assert query.filter._to_dict() == expected_filter -@CrossSync.sync_output(path="tests.unit.data._sync.test_client.TestReadRowsSharded") +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRowsSharded") class TestReadRowsShardedAsync: @CrossSync.convert( replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} @@ -2198,7 +2198,7 @@ async def mock_call(*args, **kwargs): ) -@CrossSync.sync_output(path="tests.unit.data._sync.test_client.TestSampleRowKeys") +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestSampleRowKeys") class TestSampleRowKeysAsync: @CrossSync.convert( replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} @@ -2354,7 +2354,7 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio await table.sample_row_keys() -@CrossSync.sync_output( +@CrossSync.export_sync( path="tests.unit.data._sync.test_client.TestMutateRow", ) class TestMutateRowAsync: @@ -2535,7 +2535,7 @@ async def test_mutate_row_no_mutations(self, mutations): assert e.value.args[0] == "No mutations provided" -@CrossSync.sync_output( +@CrossSync.export_sync( path="tests.unit.data._sync.test_client.TestBulkMutateRows", ) class TestBulkMutateRowsAsync: @@ -2921,7 +2921,7 @@ async def test_bulk_mutate_error_recovery(self): await table.bulk_mutate_rows(entries, operation_timeout=1000) -@CrossSync.sync_output(path="tests.unit.data._sync.test_client.TestCheckAndMutateRow") +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestCheckAndMutateRow") class TestCheckAndMutateRowAsync: @CrossSync.convert( replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} @@ -3076,7 +3076,7 @@ async def test_check_and_mutate_mutations_parsing(self): ) -@CrossSync.sync_output(path="tests.unit.data._sync.test_client.TestReadModifyWriteRow") +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadModifyWriteRow") class TestReadModifyWriteRowAsync: @CrossSync.convert( replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 3d31eeacf..db3da531d 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -43,7 +43,7 @@ import mock # type: ignore -@CrossSync.sync_output(path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl") +@CrossSync.export_sync(path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl") class Test_FlowControl: @staticmethod @CrossSync.convert(replace_symbols={"_FlowControlAsync": "_FlowControl"}) @@ -319,7 +319,7 @@ async def test_add_to_flow_oversize(self): assert len(count_results) == 1 -@CrossSync.sync_output( +@CrossSync.export_sync( path="tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" ) class TestMutationsBatcherAsync: diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 5b87d81b0..901be0ea1 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -64,7 +64,7 @@ class TestFile(proto.Message): # noqa: F811 read_rows_tests = proto.RepeatedField(proto.MESSAGE, number=1, message=ReadRowsTest) -@CrossSync.sync_output( +@CrossSync.export_sync( path="tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", ) 
class TestReadRowsAcceptanceAsync: From 49b48080289fd6c96cb6ac6782c2b55230818085 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 14:25:10 -0700 Subject: [PATCH 143/360] import instead of duplicate --- .../data/_async/test_read_rows_acceptance.py | 29 ++----------------- .../data/_sync/test_read_rows_acceptance.py | 3 +- 2 files changed, 3 insertions(+), 29 deletions(-) diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 901be0ea1..7cdd2c180 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -26,44 +26,19 @@ from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.row import Row +from ...v2_client.test_row_merger import ReadRowsTest, TestFile + from google.cloud.bigtable.data._sync.cross_sync import CrossSync if CrossSync.is_async: from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data._async.client import BigtableDataClientAsync else: - from .._async.test_read_rows_acceptance import ReadRowsTest # noqa: F401 - from .._async.test_read_rows_acceptance import TestFile # noqa: F401 from google.cloud.bigtable.data._sync._read_rows import ( # noqa: F401 _ReadRowsOperation, ) from google.cloud.bigtable.data._sync.client import BigtableDataClient # noqa: F401 - -# TODO: autogenerate protos from -# https://github.com/googleapis/conformance-tests/blob/main/bigtable/v2/proto/google/cloud/conformance/bigtable/v2/tests.proto -class ReadRowsTest(proto.Message): # noqa: F811 - class Result(proto.Message): - row_key = proto.Field(proto.STRING, number=1) - family_name = proto.Field(proto.STRING, number=2) - qualifier = proto.Field(proto.STRING, number=3) - timestamp_micros = proto.Field(proto.INT64, number=4) - value = proto.Field(proto.STRING, number=5) - label = proto.Field(proto.STRING, number=6) - error = proto.Field(proto.BOOL, number=7) - - description = proto.Field(proto.STRING, number=1) - chunks = proto.RepeatedField( - proto.MESSAGE, number=2, message=ReadRowsResponse.CellChunk - ) - results = proto.RepeatedField(proto.MESSAGE, number=3, message=Result) - - -class TestFile(proto.Message): # noqa: F811 - __test__ = False - read_rows_tests = proto.RepeatedField(proto.MESSAGE, number=1, message=ReadRowsTest) - - @CrossSync.export_sync( path="tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", ) diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync/test_read_rows_acceptance.py index 3553d5fbd..6baef4a4d 100644 --- a/tests/unit/data/_sync/test_read_rows_acceptance.py +++ b/tests/unit/data/_sync/test_read_rows_acceptance.py @@ -22,13 +22,12 @@ from google.cloud.bigtable_v2 import ReadRowsResponse from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.row import Row +from ...v2_client.test_row_merger import ReadRowsTest, TestFile from google.cloud.bigtable.data._sync.cross_sync import CrossSync if CrossSync._Sync_Impl.is_async: pass else: - from .._async.test_read_rows_acceptance import ReadRowsTest - from .._async.test_read_rows_acceptance import TestFile from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation from google.cloud.bigtable.data._sync.client import BigtableDataClient From 18e4977c36b6a3d5102075325f7c35a9657e5664 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 14:27:16 -0700 Subject: [PATCH 144/360] 
From 18e4977c36b6a3d5102075325f7c35a9657e5664 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 11 Jul 2024 14:27:16 -0700
Subject: [PATCH 144/360] import sync classes

---
 google/cloud/bigtable/data/__init__.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py
index cdb7622b6..fd44fe86c 100644
--- a/google/cloud/bigtable/data/__init__.py
+++ b/google/cloud/bigtable/data/__init__.py
@@ -20,10 +20,10 @@
 from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync

-# from google.cloud.bigtable.data._sync.client import BigtableDataClient
-# from google.cloud.bigtable.data._sync.client import Table
+from google.cloud.bigtable.data._sync.client import BigtableDataClient
+from google.cloud.bigtable.data._sync.client import Table

-# from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher
+from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher

 from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
 from google.cloud.bigtable.data.read_rows_query import RowRange
@@ -53,9 +53,9 @@
 __version__: str = package_version.__version__

 __all__ = (
-    # "BigtableDataClient",
-    # "Table",
-    # "MutationsBatcher",
+    "BigtableDataClient",
+    "Table",
+    "MutationsBatcher",
     "BigtableDataClientAsync",
     "TableAsync",
     "MutationsBatcherAsync",
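For context on what these export changes affect, here is a small, hypothetical usage sketch of the async surface that stays exported from google.cloud.bigtable.data throughout the series; the project, instance, and table IDs are placeholders, not values from the patches.

import asyncio

from google.cloud.bigtable.data import BigtableDataClientAsync, ReadRowsQuery

async def main():
    # The async client is created and used inside a running event loop.
    async with BigtableDataClientAsync(project="my-project") as client:
        table = client.get_table("my-instance", "my-table")
        # read_rows returns the full result list; read_rows_stream yields rows lazily.
        rows = await table.read_rows(ReadRowsQuery(limit=10))
        for row in rows:
            print(row.row_key)

asyncio.run(main())

The sync re-exports added in this patch are rolled back in the next one, so only the *Async names shown above are reliably importable at this point in the series.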
From 48bb06f20da9f3d298dbbf7fc2ab208569108da9 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 11 Jul 2024 14:43:44 -0700
Subject: [PATCH 145/360] removed sync classes

---
 google/cloud/bigtable/data/__init__.py        |    5 -
 google/cloud/bigtable/data/_sync/__init__.py  |    0
 .../cloud/bigtable/data/_sync/_mutate_rows.py |  184 --
 .../cloud/bigtable/data/_sync/_read_rows.py   |  309 --
 google/cloud/bigtable/data/_sync/client.py    | 1164 -------
 .../bigtable/data/_sync/mutations_batcher.py  |  451 ---
 .../cloud/bigtable/data/_sync/sync_gen.yaml   |   62 -
 .../bigtable/data/_sync/system_tests.yaml     |   25 -
 .../cloud/bigtable/data/_sync/transformers.py |  337 --
 .../cloud/bigtable/data/_sync/unit_tests.yaml |  116 -
 tests/system/data/test_system.py              |  812 -----
 .../data/_async/test_mutations_batcher.py     |    4 +-
 tests/unit/data/_sync/__init__.py             |    0
 tests/unit/data/_sync/test__mutate_rows.py    |  321 --
 tests/unit/data/_sync/test__read_rows.py      |  360 ---
 tests/unit/data/_sync/test_client.py          | 2740 -----------------
 .../unit/data/_sync/test_mutations_batcher.py | 1104 -------
 .../data/_sync/test_read_rows_acceptance.py   |  333 --
 18 files changed, 3 insertions(+), 8324 deletions(-)
 delete mode 100644 google/cloud/bigtable/data/_sync/__init__.py
 delete mode 100644 google/cloud/bigtable/data/_sync/_mutate_rows.py
 delete mode 100644 google/cloud/bigtable/data/_sync/_read_rows.py
 delete mode 100644 google/cloud/bigtable/data/_sync/client.py
 delete mode 100644 google/cloud/bigtable/data/_sync/mutations_batcher.py
 delete mode 100644 google/cloud/bigtable/data/_sync/sync_gen.yaml
 delete mode 100644 google/cloud/bigtable/data/_sync/system_tests.yaml
 delete mode 100644 google/cloud/bigtable/data/_sync/transformers.py
 delete mode 100644 google/cloud/bigtable/data/_sync/unit_tests.yaml
 delete mode 100644 tests/system/data/test_system.py
 delete mode 100644 tests/unit/data/_sync/__init__.py
 delete mode 100644 tests/unit/data/_sync/test__mutate_rows.py
 delete mode 100644 tests/unit/data/_sync/test__read_rows.py
 delete mode 100644 tests/unit/data/_sync/test_client.py
 delete mode 100644 tests/unit/data/_sync/test_mutations_batcher.py
 delete mode 100644 tests/unit/data/_sync/test_read_rows_acceptance.py

diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py
index fd44fe86c..b52d36b50 100644
--- a/google/cloud/bigtable/data/__init__.py
+++ b/google/cloud/bigtable/data/__init__.py
@@ -20,11 +20,6 @@
 from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync

-from google.cloud.bigtable.data._sync.client import BigtableDataClient
-from google.cloud.bigtable.data._sync.client import Table
-
-from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher
-
 from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery
 from google.cloud.bigtable.data.read_rows_query import RowRange
 from google.cloud.bigtable.data.row import Row
diff --git a/google/cloud/bigtable/data/_sync/__init__.py b/google/cloud/bigtable/data/_sync/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py
deleted file mode 100644
index d65cf3c61..000000000
--- a/google/cloud/bigtable/data/_sync/_mutate_rows.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# Copyright 2024 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-# This file is automatically generated by CrossSync. Do not edit manually.
-from __future__ import annotations
-from typing import Sequence, TYPE_CHECKING
-import functools
-from google.api_core import exceptions as core_exceptions
-from google.api_core import retry as retries
-from google.cloud.bigtable.data._helpers import _make_metadata
-from google.cloud.bigtable.data._helpers import _attempt_timeout_generator
-from google.cloud.bigtable.data._helpers import _retry_exception_factory
-from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete
-from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
-from google.cloud.bigtable.data.exceptions import RetryExceptionGroup
-from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
-from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT
-from google.cloud.bigtable.data._sync.cross_sync import CrossSync
-
-if not CrossSync._Sync_Impl.is_async:
-    from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto
-if TYPE_CHECKING:
-    from google.cloud.bigtable.data.mutations import RowMutationEntry
-
-    if CrossSync._Sync_Impl.is_async:
-        pass
-    else:
-        from google.cloud.bigtable.data._sync.client import Table
-        from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient
-
-
-class _MutateRowsOperation:
-    """
-    MutateRowsOperation manages the logic of sending a set of row mutations,
-    and retrying on failed entries. It manages this using the _run_attempt
-    function, which attempts to mutate all outstanding entries, and raises
-    _MutateRowsIncomplete if any retryable errors are encountered.
-
-    Errors are exposed as a MutationsExceptionGroup, which contains a list of
-    exceptions organized by the related failed mutation entries.
- - Args: - gapic_client: the client to use for the mutate_rows call - table: the table associated with the request - mutation_entries: a list of RowMutationEntry objects to send to the server - operation_timeout: the timeout to use for the entire operation, in seconds. - attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds. - If not specified, the request will run until operation_timeout is reached. - """ - - def __init__( - self, - gapic_client: "BigtableClient", - table: "Table", - mutation_entries: list["RowMutationEntry"], - operation_timeout: float, - attempt_timeout: float | None, - retryable_exceptions: Sequence[type[Exception]] = (), - ): - total_mutations = sum((len(entry.mutations) for entry in mutation_entries)) - if total_mutations > _MUTATE_ROWS_REQUEST_MUTATION_LIMIT: - raise ValueError( - f"mutate_rows requests can contain at most {_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across all entries. Found {total_mutations}." - ) - metadata = _make_metadata(table.table_name, table.app_profile_id) - self._gapic_fn = functools.partial( - gapic_client.mutate_rows, - table_name=table.table_name, - app_profile_id=table.app_profile_id, - metadata=metadata, - retry=None, - ) - self.is_retryable = retries.if_exception_type( - *retryable_exceptions, _MutateRowsIncomplete - ) - sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - self._operation = lambda: CrossSync._Sync_Impl.retry_target( - self._run_attempt, - self.is_retryable, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, - ) - self.timeout_generator = _attempt_timeout_generator( - attempt_timeout, operation_timeout - ) - self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries] - self.remaining_indices = list(range(len(self.mutations))) - self.errors: dict[int, list[Exception]] = {} - - def start(self): - """Start the operation, and run until completion - - Raises: - MutationsExceptionGroup: if any mutations failed""" - try: - self._operation() - except Exception as exc: - incomplete_indices = self.remaining_indices.copy() - for idx in incomplete_indices: - self._handle_entry_error(idx, exc) - finally: - all_errors: list[Exception] = [] - for idx, exc_list in self.errors.items(): - if len(exc_list) == 0: - raise core_exceptions.ClientError( - f"Mutation {idx} failed with no associated errors" - ) - elif len(exc_list) == 1: - cause_exc = exc_list[0] - else: - cause_exc = RetryExceptionGroup(exc_list) - entry = self.mutations[idx].entry - all_errors.append(FailedMutationEntryError(idx, entry, cause_exc)) - if all_errors: - raise MutationsExceptionGroup(all_errors, len(self.mutations)) - - def _run_attempt(self): - """Run a single attempt of the mutate_rows rpc. 
- - Raises: - _MutateRowsIncomplete: if there are failed mutations eligible for - retry after the attempt is complete - GoogleAPICallError: if the gapic rpc fails""" - request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] - active_request_indices = { - req_idx: orig_idx for req_idx, orig_idx in enumerate(self.remaining_indices) - } - self.remaining_indices = [] - if not request_entries: - return - try: - result_generator = self._gapic_fn( - timeout=next(self.timeout_generator), - entries=request_entries, - retry=None, - ) - for result_list in result_generator: - for result in result_list.entries: - orig_idx = active_request_indices[result.index] - entry_error = core_exceptions.from_grpc_status( - result.status.code, - result.status.message, - details=result.status.details, - ) - if result.status.code != 0: - self._handle_entry_error(orig_idx, entry_error) - elif orig_idx in self.errors: - del self.errors[orig_idx] - del active_request_indices[result.index] - except Exception as exc: - for idx in active_request_indices.values(): - self._handle_entry_error(idx, exc) - raise - if self.remaining_indices: - raise _MutateRowsIncomplete - - def _handle_entry_error(self, idx: int, exc: Exception): - """Add an exception to the list of exceptions for a given mutation index, - and add the index to the list of remaining indices if the exception is - retryable. - - Args: - idx: the index of the mutation that failed - exc: the exception to add to the list""" - entry = self.mutations[idx].entry - self.errors.setdefault(idx, []).append(exc) - if ( - entry.is_idempotent() - and self.is_retryable(exc) - and (idx not in self.remaining_indices) - ): - self.remaining_indices.append(idx) diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py deleted file mode 100644 index 32603a17b..000000000 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is automatically generated by CrossSync. Do not edit manually. 
-from __future__ import annotations -from typing import TYPE_CHECKING, Sequence -from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB -from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB -from google.cloud.bigtable_v2.types import RowSet as RowSetPB -from google.cloud.bigtable_v2.types import RowRange as RowRangePB -from google.cloud.bigtable.data.row import Row, Cell -from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery -from google.cloud.bigtable.data.exceptions import InvalidChunk -from google.cloud.bigtable.data.exceptions import _RowSetComplete -from google.cloud.bigtable.data import _helpers -from google.api_core import retry as retries -from google.api_core.retry import exponential_sleep_generator -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -if not CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable.data._async._read_rows import _ResetRow -if TYPE_CHECKING: - if CrossSync._Sync_Impl.is_async: - pass - else: - from google.cloud.bigtable.data._sync.client import Table - - -class _ReadRowsOperation: - """ - ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream - into a stream of Row objects. - - ReadRowsOperation.merge_row_response_stream takes in a stream of ReadRowsResponse - and turns them into a stream of Row objects using an internal - StateMachine. - - ReadRowsOperation(request, client) handles row merging logic end-to-end, including - performing retries on stream errors. - - Args: - query: The query to execute - table: The table to send the request to - operation_timeout: The total time to allow for the operation, in seconds - attempt_timeout: The time to allow for each individual attempt, in seconds - retryable_exceptions: A list of exceptions that should trigger a retry - """ - - __slots__ = ( - "attempt_timeout_gen", - "operation_timeout", - "request", - "table", - "_predicate", - "_metadata", - "_last_yielded_row_key", - "_remaining_count", - ) - - def __init__( - self, - query: ReadRowsQuery, - table: "Table", - operation_timeout: float, - attempt_timeout: float, - retryable_exceptions: Sequence[type[Exception]] = (), - ): - self.attempt_timeout_gen = _helpers._attempt_timeout_generator( - attempt_timeout, operation_timeout - ) - self.operation_timeout = operation_timeout - if isinstance(query, dict): - self.request = ReadRowsRequestPB( - **query, - table_name=table.table_name, - app_profile_id=table.app_profile_id, - ) - else: - self.request = query._to_pb(table) - self.table = table - self._predicate = retries.if_exception_type(*retryable_exceptions) - self._metadata = _helpers._make_metadata(table.table_name, table.app_profile_id) - self._last_yielded_row_key: bytes | None = None - self._remaining_count: int | None = self.request.rows_limit or None - - def start_operation(self) -> CrossSync._Sync_Impl.Iterable[Row]: - """Start the read_rows operation, retrying on retryable errors. - - Yields: - Row: The next row in the stream""" - return CrossSync._Sync_Impl.retry_target_stream( - self._read_rows_attempt, - self._predicate, - exponential_sleep_generator(0.01, 60, multiplier=2), - self.operation_timeout, - exception_factory=_helpers._retry_exception_factory, - ) - - def _read_rows_attempt(self) -> CrossSync._Sync_Impl.Iterable[Row]: - """Attempt a single read_rows rpc call. - This function is intended to be wrapped by retry logic, - which will call this function until it succeeds or - a non-retryable error is raised. 
- - Yields: - Row: The next row in the stream""" - if self._last_yielded_row_key is not None: - try: - self.request.rows = self._revise_request_rowset( - row_set=self.request.rows, - last_seen_row_key=self._last_yielded_row_key, - ) - except _RowSetComplete: - return self.merge_rows(None) - if self._remaining_count is not None: - self.request.rows_limit = self._remaining_count - if self._remaining_count == 0: - return self.merge_rows(None) - gapic_stream = self.table.client._gapic_client.read_rows( - self.request, - timeout=next(self.attempt_timeout_gen), - metadata=self._metadata, - retry=None, - ) - chunked_stream = self.chunk_stream(gapic_stream) - return self.merge_rows(chunked_stream) - - def chunk_stream( - self, - stream: CrossSync._Sync_Impl.Awaitable[ - CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB] - ], - ) -> CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB.CellChunk]: - """process chunks out of raw read_rows stream - - Args: - stream: the raw read_rows stream from the gapic client - Yields: - ReadRowsResponsePB.CellChunk: the next chunk in the stream""" - for resp in stream: - resp = resp._pb - if resp.last_scanned_row_key: - if ( - self._last_yielded_row_key is not None - and resp.last_scanned_row_key <= self._last_yielded_row_key - ): - raise InvalidChunk("last scanned out of order") - self._last_yielded_row_key = resp.last_scanned_row_key - current_key = None - for c in resp.chunks: - if current_key is None: - current_key = c.row_key - if current_key is None: - raise InvalidChunk("first chunk is missing a row key") - elif ( - self._last_yielded_row_key - and current_key <= self._last_yielded_row_key - ): - raise InvalidChunk("row keys should be strictly increasing") - yield c - if c.reset_row: - current_key = None - elif c.commit_row: - self._last_yielded_row_key = current_key - if self._remaining_count is not None: - self._remaining_count -= 1 - if self._remaining_count < 0: - raise InvalidChunk("emit count exceeds row limit") - current_key = None - - @staticmethod - def merge_rows( - chunks: CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB.CellChunk] | None, - ) -> CrossSync._Sync_Impl.Iterable[Row]: - """Merge chunks into rows - - Args: - chunks: the chunk stream to merge - Yields: - Row: the next row in the stream""" - if chunks is None: - return - it = chunks.__iter__() - while True: - try: - c = it.__next__() - except CrossSync._Sync_Impl.StopIteration: - return - row_key = c.row_key - if not row_key: - raise InvalidChunk("first row chunk is missing key") - cells = [] - family: str | None = None - qualifier: bytes | None = None - try: - while True: - if c.reset_row: - raise _ResetRow(c) - k = c.row_key - f = c.family_name.value - q = c.qualifier.value if c.HasField("qualifier") else None - if k and k != row_key: - raise InvalidChunk("unexpected new row key") - if f: - family = f - if q is not None: - qualifier = q - else: - raise InvalidChunk("new family without qualifier") - elif family is None: - raise InvalidChunk("missing family") - elif q is not None: - if family is None: - raise InvalidChunk("new qualifier without family") - qualifier = q - elif qualifier is None: - raise InvalidChunk("missing qualifier") - ts = c.timestamp_micros - labels = c.labels if c.labels else [] - value = c.value - if c.value_size > 0: - buffer = [value] - while c.value_size > 0: - c = it.__next__() - t = c.timestamp_micros - cl = c.labels - k = c.row_key - if ( - c.HasField("family_name") - and c.family_name.value != family - ): - raise InvalidChunk("family changed mid cell") - if ( - 
c.HasField("qualifier") - and c.qualifier.value != qualifier - ): - raise InvalidChunk("qualifier changed mid cell") - if t and t != ts: - raise InvalidChunk("timestamp changed mid cell") - if cl and cl != labels: - raise InvalidChunk("labels changed mid cell") - if k and k != row_key: - raise InvalidChunk("row key changed mid cell") - if c.reset_row: - raise _ResetRow(c) - buffer.append(c.value) - value = b"".join(buffer) - cells.append( - Cell(value, row_key, family, qualifier, ts, list(labels)) - ) - if c.commit_row: - yield Row(row_key, cells) - break - c = it.__next__() - except _ResetRow as e: - c = e.chunk - if ( - c.row_key - or c.HasField("family_name") - or c.HasField("qualifier") - or c.timestamp_micros - or c.labels - or c.value - ): - raise InvalidChunk("reset row with data") - continue - except CrossSync._Sync_Impl.StopIteration: - raise InvalidChunk("premature end of stream") - - @staticmethod - def _revise_request_rowset(row_set: RowSetPB, last_seen_row_key: bytes) -> RowSetPB: - """Revise the rows in the request to avoid ones we've already processed. - - Args: - row_set: the row set from the request - last_seen_row_key: the last row key encountered - Returns: - RowSetPB: the new rowset after adusting for the last seen key - Raises: - _RowSetComplete: if there are no rows left to process after the revision""" - if row_set is None or (not row_set.row_ranges and (not row_set.row_keys)): - last_seen = last_seen_row_key - return RowSetPB(row_ranges=[RowRangePB(start_key_open=last_seen)]) - adjusted_keys: list[bytes] = [ - k for k in row_set.row_keys if k > last_seen_row_key - ] - adjusted_ranges: list[RowRangePB] = [] - for row_range in row_set.row_ranges: - end_key = row_range.end_key_closed or row_range.end_key_open or None - if end_key is None or end_key > last_seen_row_key: - new_range = RowRangePB(row_range) - start_key = row_range.start_key_closed or row_range.start_key_open - if start_key is None or start_key <= last_seen_row_key: - new_range.start_key_open = last_seen_row_key - adjusted_ranges.append(new_range) - if len(adjusted_keys) == 0 and len(adjusted_ranges) == 0: - raise _RowSetComplete() - return RowSetPB(row_keys=adjusted_keys, row_ranges=adjusted_ranges) diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py deleted file mode 100644 index 833573365..000000000 --- a/google/cloud/bigtable/data/_sync/client.py +++ /dev/null @@ -1,1164 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is automatically generated by CrossSync. Do not edit manually. 
-from __future__ import annotations -from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING -import asyncio -import time -import warnings -import random -import os -import concurrent.futures -from functools import partial -from grpc import Channel -from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.transports.base import ( - DEFAULT_CLIENT_INFO, -) -from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest -from google.cloud.client import ClientWithProject -from google.cloud.environment_vars import BIGTABLE_EMULATOR -from google.api_core import retry as retries -from google.api_core.exceptions import DeadlineExceeded -from google.api_core.exceptions import ServiceUnavailable -from google.api_core.exceptions import Aborted -import google.auth.credentials -import google.auth._default -from google.api_core import client_options as client_options_lib -from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT -from google.cloud.bigtable.data.row import Row -from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery -from google.cloud.bigtable.data.exceptions import FailedQueryShardError -from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup -from google.cloud.bigtable.data import _helpers -from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data._helpers import _MB_SIZE -from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry -from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule -from google.cloud.bigtable.data.row_filters import RowFilter -from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter -from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter -from google.cloud.bigtable.data.row_filters import RowFilterChain -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -if CrossSync._Sync_Impl.is_async: - pass -else: - from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation - from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledBigtableGrpcTransport, - ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledChannel, - ) - from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient - from typing import Iterable -if TYPE_CHECKING: - from google.cloud.bigtable.data._helpers import RowKeySamples - from google.cloud.bigtable.data._helpers import ShardedQuery - - -class BigtableDataClient(ClientWithProject): - def __init__( - self, - *, - project: str | None = None, - pool_size: int = 3, - credentials: google.auth.credentials.Credentials | None = None, - client_options: dict[str, Any] - | "google.api_core.client_options.ClientOptions" - | None = None, - ): - """Create a client instance for the Bigtable Data API - - Client should be created within an async context (running event loop) - - Args: - project: the project which the client acts on behalf of. - If not passed, falls back to the default inferred - from the environment. - pool_size: The number of grpc channels to maintain - in the internal channel pool. - credentials: - Thehe OAuth2 Credentials to use for this - client. 
If not passed (and if no ``_http`` object is - passed), falls back to the default inferred from the - environment. - client_options: - Client options used to set user options - on the client. API Endpoint should be set through client_options. - Raises: - RuntimeError: if called outside of an async context (no running event loop) - ValueError: if pool_size is less than 1""" - transport_str = f"bt-{self._client_version()}-{pool_size}" - transport = PooledBigtableGrpcTransport.with_fixed_size(pool_size) - BigtableClientMeta._transport_registry[transport_str] = transport - client_info = DEFAULT_CLIENT_INFO - client_info.client_library_version = self._client_version() - if type(client_options) is dict: - client_options = client_options_lib.from_dict(client_options) - client_options = cast( - Optional[client_options_lib.ClientOptions], client_options - ) - self._emulator_host = os.getenv(BIGTABLE_EMULATOR) - if self._emulator_host is not None: - if credentials is None: - credentials = google.auth.credentials.AnonymousCredentials() - if project is None: - project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT - ClientWithProject.__init__( - self, - credentials=credentials, - project=project, - client_options=client_options, - ) - self._gapic_client = BigtableClient( - transport=transport_str, - credentials=credentials, - client_options=client_options, - client_info=client_info, - ) - self._is_closed = CrossSync._Sync_Impl.Event() - self.transport = cast(PooledBigtableGrpcTransport, self._gapic_client.transport) - self._active_instances: Set[_helpers._WarmedInstanceKey] = set() - self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} - self._channel_init_time = time.monotonic() - self._channel_refresh_tasks: list[CrossSync._Sync_Impl.Task[None]] = [] - self._executor = ( - concurrent.futures.ThreadPoolExecutor() - if not CrossSync._Sync_Impl.is_async - else None - ) - if self._emulator_host is not None: - warnings.warn( - "Connecting to Bigtable emulator at {}".format(self._emulator_host), - RuntimeWarning, - stacklevel=2, - ) - self.transport._grpc_channel = PooledChannel( - pool_size=pool_size, host=self._emulator_host, insecure=True - ) - self.transport._stubs = {} - self.transport._prep_wrapped_messages(client_info) - else: - try: - self._start_background_channel_refresh() - except RuntimeError: - warnings.warn( - f"{self.__class__.__name__} should be started in an asyncio event loop. 
Channel refresh will not be started", - RuntimeWarning, - stacklevel=2, - ) - - @staticmethod - def _client_version() -> str: - """Helper function to return the client version string for this client""" - if CrossSync._Sync_Impl.is_async: - return f"{google.cloud.bigtable.__version__}-data-async" - else: - return f"{google.cloud.bigtable.__version__}-data" - - def _start_background_channel_refresh(self) -> None: - """Starts a background task to ping and warm each channel in the pool - - Raises: - RuntimeError: if not called in an asyncio event loop""" - if ( - not self._channel_refresh_tasks - and (not self._emulator_host) - and (not self._is_closed.is_set()) - ): - if CrossSync._Sync_Impl.is_async: - asyncio.get_running_loop() - for channel_idx in range(self.transport.pool_size): - refresh_task = CrossSync._Sync_Impl.create_task( - self._manage_channel, - channel_idx, - sync_executor=self._executor, - task_name=f"{self.__class__.__name__} channel refresh {channel_idx}", - ) - self._channel_refresh_tasks.append(refresh_task) - refresh_task.add_done_callback( - lambda _: self._channel_refresh_tasks.remove(refresh_task) - if refresh_task in self._channel_refresh_tasks - else None - ) - - def close(self, timeout: float | None = None): - """Cancel all background tasks""" - self._is_closed.set() - for task in self._channel_refresh_tasks: - task.cancel() - self.transport.close() - if self._executor: - self._executor.shutdown(wait=False) - CrossSync._Sync_Impl.wait(self._channel_refresh_tasks, timeout=timeout) - - def _ping_and_warm_instances( - self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None - ) -> list[BaseException | None]: - """Prepares the backend for requests on a channel - - Pings each Bigtable instance registered in `_active_instances` on the client - - Args: - channel: grpc channel to warm - instance_key: if provided, only warm the instance associated with the key - Returns: - list[BaseException | None]: sequence of results or exceptions from the ping requests - """ - instance_list = ( - [instance_key] if instance_key is not None else self._active_instances - ) - ping_rpc = channel.unary_unary( - "/google.bigtable.v2.Bigtable/PingAndWarm", - request_serializer=PingAndWarmRequest.serialize, - ) - partial_list = [ - partial( - ping_rpc, - request={"name": instance_name, "app_profile_id": app_profile_id}, - metadata=[ - ( - "x-goog-request-params", - f"name={instance_name}&app_profile_id={app_profile_id}", - ) - ], - wait_for_ready=True, - ) - for instance_name, table_name, app_profile_id in instance_list - ] - result_list = CrossSync._Sync_Impl.gather_partials( - partial_list, return_exceptions=True, sync_executor=self._executor - ) - return [r or None for r in result_list] - - def _manage_channel( - self, - channel_idx: int, - refresh_interval_min: float = 60 * 35, - refresh_interval_max: float = 60 * 45, - grace_period: float = 60 * 10, - ) -> None: - """Background coroutine that periodically refreshes and warms a grpc channel - - The backend will automatically close channels after 60 minutes, so - `refresh_interval` + `grace_period` should be < 60 minutes - - Runs continuously until the client is closed - - Args: - channel_idx: index of the channel in the transport's channel pool - refresh_interval_min: minimum interval before initiating refresh - process in seconds. Actual interval will be a random value - between `refresh_interval_min` and `refresh_interval_max` - refresh_interval_max: maximum interval before initiating refresh - process in seconds. 
Actual interval will be a random value - between `refresh_interval_min` and `refresh_interval_max` - grace_period: time to allow previous channel to serve existing - requests before closing, in seconds""" - first_refresh = self._channel_init_time + random.uniform( - refresh_interval_min, refresh_interval_max - ) - next_sleep = max(first_refresh - time.monotonic(), 0) - if next_sleep > 0: - channel = self.transport.channels[channel_idx] - self._ping_and_warm_instances(channel) - while not self._is_closed.is_set(): - CrossSync._Sync_Impl.event_wait( - self._is_closed, next_sleep, async_break_early=False - ) - if self._is_closed.is_set(): - break - new_channel = self.transport.grpc_channel._create_channel() - self._ping_and_warm_instances(new_channel) - start_timestamp = time.monotonic() - self.transport.replace_channel( - channel_idx, - grace=grace_period, - new_channel=new_channel, - event=self._is_closed, - ) - next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) - next_sleep = next_refresh - (time.monotonic() - start_timestamp) - - def _register_instance(self, instance_id: str, owner: Table) -> None: - """Registers an instance with the client, and warms the channel pool - for the instance - The client will periodically refresh grpc channel pool used to make - requests, and new channels will be warmed for each registered instance - Channels will not be refreshed unless at least one instance is registered - - Args: - instance_id: id of the instance to register. - owner: table that owns the instance. Owners will be tracked in - _instance_owners, and instances will only be unregistered when all - owners call _remove_instance_registration""" - instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _helpers._WarmedInstanceKey( - instance_name, owner.table_name, owner.app_profile_id - ) - self._instance_owners.setdefault(instance_key, set()).add(id(owner)) - if instance_name not in self._active_instances: - self._active_instances.add(instance_key) - if self._channel_refresh_tasks: - for channel in self.transport.channels: - self._ping_and_warm_instances(channel, instance_key) - else: - self._start_background_channel_refresh() - - def _remove_instance_registration(self, instance_id: str, owner: Table) -> bool: - """Removes an instance from the client's registered instances, to prevent - warming new channels for the instance - - If instance_id is not registered, or is still in use by other tables, returns False - - Args: - instance_id: id of the instance to remove - owner: table that owns the instance. Owners will be tracked in - _instance_owners, and instances will only be unregistered when all - owners call _remove_instance_registration - Returns: - bool: True if instance was removed, else False""" - instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _helpers._WarmedInstanceKey( - instance_name, owner.table_name, owner.app_profile_id - ) - owner_list = self._instance_owners.get(instance_key, set()) - try: - owner_list.remove(id(owner)) - if len(owner_list) == 0: - self._active_instances.remove(instance_key) - return True - except KeyError: - return False - - def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> Table: - """Returns a table instance for making data API requests. All arguments are passed - directly to the Table constructor. - - Args: - instance_id: The Bigtable instance ID to associate with this client. 
- instance_id is combined with the client's project to fully - specify the instance - table_id: The ID of the table. table_id is combined with the - instance_id and the client's project to fully specify the table - app_profile_id: The app profile to associate with requests. - https://cloud.google.com/bigtable/docs/app-profiles - default_read_rows_operation_timeout: The default timeout for read rows - operations, in seconds. If not set, defaults to 600 seconds (10 minutes) - default_read_rows_attempt_timeout: The default timeout for individual - read rows rpc requests, in seconds. If not set, defaults to 20 seconds - default_mutate_rows_operation_timeout: The default timeout for mutate rows - operations, in seconds. If not set, defaults to 600 seconds (10 minutes) - default_mutate_rows_attempt_timeout: The default timeout for individual - mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds - default_operation_timeout: The default timeout for all other operations, in - seconds. If not set, defaults to 60 seconds - default_attempt_timeout: The default timeout for all other individual rpc - requests, in seconds. If not set, defaults to 20 seconds - default_read_rows_retryable_errors: a list of errors that will be retried - if encountered during read_rows and related operations. - Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) - default_mutate_rows_retryable_errors: a list of errors that will be retried - if encountered during mutate_rows and related operations. - Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) - default_retryable_errors: a list of errors that will be retried if - encountered during all other operations. - Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) - Returns: - TableAsync: a table instance for making data API requests - Raises: - RuntimeError: if called outside of an async context (no running event loop) - """ - return Table(self, instance_id, table_id, *args, **kwargs) - - def __enter__(self): - self._start_background_channel_refresh() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - self._gapic_client.__exit__(exc_type, exc_val, exc_tb) - - -class Table: - """ - Main Data API surface - - Table object maintains table_id, and app_profile_id context, and passes them with - each call - """ - - def __init__( - self, - client: BigtableDataClient, - instance_id: str, - table_id: str, - app_profile_id: str | None = None, - *, - default_read_rows_operation_timeout: float = 600, - default_read_rows_attempt_timeout: float | None = 20, - default_mutate_rows_operation_timeout: float = 600, - default_mutate_rows_attempt_timeout: float | None = 60, - default_operation_timeout: float = 60, - default_attempt_timeout: float | None = 20, - default_read_rows_retryable_errors: Sequence[type[Exception]] = ( - DeadlineExceeded, - ServiceUnavailable, - Aborted, - ), - default_mutate_rows_retryable_errors: Sequence[type[Exception]] = ( - DeadlineExceeded, - ServiceUnavailable, - ), - default_retryable_errors: Sequence[type[Exception]] = ( - DeadlineExceeded, - ServiceUnavailable, - ), - ): - """Initialize a Table instance - - Must be created within an async context (running event loop) - - Args: - instance_id: The Bigtable instance ID to associate with this client. - instance_id is combined with the client's project to fully - specify the instance - table_id: The ID of the table. 
table_id is combined with the - instance_id and the client's project to fully specify the table - app_profile_id: The app profile to associate with requests. - https://cloud.google.com/bigtable/docs/app-profiles - default_read_rows_operation_timeout: The default timeout for read rows - operations, in seconds. If not set, defaults to 600 seconds (10 minutes) - default_read_rows_attempt_timeout: The default timeout for individual - read rows rpc requests, in seconds. If not set, defaults to 20 seconds - default_mutate_rows_operation_timeout: The default timeout for mutate rows - operations, in seconds. If not set, defaults to 600 seconds (10 minutes) - default_mutate_rows_attempt_timeout: The default timeout for individual - mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds - default_operation_timeout: The default timeout for all other operations, in - seconds. If not set, defaults to 60 seconds - default_attempt_timeout: The default timeout for all other individual rpc - requests, in seconds. If not set, defaults to 20 seconds - default_read_rows_retryable_errors: a list of errors that will be retried - if encountered during read_rows and related operations. - Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) - default_mutate_rows_retryable_errors: a list of errors that will be retried - if encountered during mutate_rows and related operations. - Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) - default_retryable_errors: a list of errors that will be retried if - encountered during all other operations. - Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) - Raises: - RuntimeError: if called outside of an async context (no running event loop) - """ - _helpers._validate_timeouts( - default_operation_timeout, default_attempt_timeout, allow_none=True - ) - _helpers._validate_timeouts( - default_read_rows_operation_timeout, - default_read_rows_attempt_timeout, - allow_none=True, - ) - _helpers._validate_timeouts( - default_mutate_rows_operation_timeout, - default_mutate_rows_attempt_timeout, - allow_none=True, - ) - self.client = client - self.instance_id = instance_id - self.instance_name = self.client._gapic_client.instance_path( - self.client.project, instance_id - ) - self.table_id = table_id - self.table_name = self.client._gapic_client.table_path( - self.client.project, instance_id, table_id - ) - self.app_profile_id = app_profile_id - self.default_operation_timeout = default_operation_timeout - self.default_attempt_timeout = default_attempt_timeout - self.default_read_rows_operation_timeout = default_read_rows_operation_timeout - self.default_read_rows_attempt_timeout = default_read_rows_attempt_timeout - self.default_mutate_rows_operation_timeout = ( - default_mutate_rows_operation_timeout - ) - self.default_mutate_rows_attempt_timeout = default_mutate_rows_attempt_timeout - self.default_read_rows_retryable_errors = ( - default_read_rows_retryable_errors or () - ) - self.default_mutate_rows_retryable_errors = ( - default_mutate_rows_retryable_errors or () - ) - self.default_retryable_errors = default_retryable_errors or () - try: - self._register_instance_future = CrossSync._Sync_Impl.create_task( - self.client._register_instance, - self.instance_id, - self, - sync_executor=self.client._executor, - ) - except RuntimeError as e: - raise RuntimeError( - f"{self.__class__.__name__} must be created within an async event loop context." 
- ) from e - - def read_rows_stream( - self, - query: ReadRowsQuery, - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - ) -> Iterable[Row]: - """Read a set of rows from the table, based on the specified query. - Returns an iterator to asynchronously stream back row data. - - Failed requests within operation_timeout will be retried based on the - retryable_errors list until operation_timeout is reached. - - Args: - query: contains details about which rows to return - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_read_rows_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_read_rows_attempt_timeout. - If None, defaults to operation_timeout. - retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_read_rows_retryable_errors - Returns: - AsyncIterable[Row]: an asynchronous iterator that yields rows returned by the query - Raises: - google.api_core.exceptions.DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions - from any retries that failed - google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error - """ - operation_timeout, attempt_timeout = _helpers._get_timeouts( - operation_timeout, attempt_timeout, self - ) - retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - row_merger = _ReadRowsOperation( - query, - self, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_exceptions=retryable_excs, - ) - return row_merger.start_operation() - - def read_rows( - self, - query: ReadRowsQuery, - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - ) -> list[Row]: - """Read a set of rows from the table, based on the specified query. - Retruns results as a list of Row objects when the request is complete. - For streamed results, use read_rows_stream. - - Failed requests within operation_timeout will be retried based on the - retryable_errors list until operation_timeout is reached. - - Args: - query: contains details about which rows to return - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_read_rows_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_read_rows_attempt_timeout. - If None, defaults to operation_timeout. - If None, defaults to the Table's default_read_rows_attempt_timeout, - or the operation_timeout if that is also None. - retryable_errors: a list of errors that will be retried if encountered. 
- Defaults to the Table's default_read_rows_retryable_errors. - Returns: - list[Row]: a list of Rows returned by the query - Raises: - google.api_core.exceptions.DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions - from any retries that failed - google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error - """ - row_generator = self.read_rows_stream( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, - ) - return [row for row in row_generator] - - def read_row( - self, - row_key: str | bytes, - *, - row_filter: RowFilter | None = None, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - ) -> Row | None: - """Read a single row from the table, based on the specified key. - - Failed requests within operation_timeout will be retried based on the - retryable_errors list until operation_timeout is reached. - - Args: - query: contains details about which rows to return - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_read_rows_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_read_rows_attempt_timeout. - If None, defaults to operation_timeout. - retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_read_rows_retryable_errors. - Returns: - Row | None: a Row object if the row exists, otherwise None - Raises: - google.api_core.exceptions.DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions - from any retries that failed - google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error - """ - if row_key is None: - raise ValueError("row_key must be string or bytes") - query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1) - results = self.read_rows( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, - ) - if len(results) == 0: - return None - return results[0] - - def read_rows_sharded( - self, - sharded_query: ShardedQuery, - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - ) -> list[Row]: - """Runs a sharded query in parallel, then return the results in a single list. - Results will be returned in the order of the input queries. - - This function is intended to be run on the results on a query.shard() call. - For example:: - - table_shard_keys = await table.sample_row_keys() - query = ReadRowsQuery(...) - shard_queries = query.shard(table_shard_keys) - results = await table.read_rows_sharded(shard_queries) - - Args: - sharded_query: a sharded query to execute - operation_timeout: the time budget for the entire operation, in seconds. 
- Failed requests will be retried within the budget. - Defaults to the Table's default_read_rows_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_read_rows_attempt_timeout. - If None, defaults to operation_timeout. - retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_read_rows_retryable_errors. - Returns: - list[Row]: a list of Rows returned by the query - Raises: - ShardedReadRowsExceptionGroup: if any of the queries failed - ValueError: if the query_list is empty""" - if not sharded_query: - raise ValueError("empty sharded_query") - operation_timeout, attempt_timeout = _helpers._get_timeouts( - operation_timeout, attempt_timeout, self - ) - rpc_timeout_generator = _helpers._attempt_timeout_generator( - operation_timeout, operation_timeout - ) - concurrency_sem = CrossSync._Sync_Impl.Semaphore(_helpers._CONCURRENCY_LIMIT) - - def read_rows_with_semaphore(query): - with concurrency_sem: - shard_timeout = next(rpc_timeout_generator) - if shard_timeout <= 0: - raise DeadlineExceeded( - "Operation timeout exceeded before starting query" - ) - return self.read_rows( - query, - operation_timeout=shard_timeout, - attempt_timeout=min(attempt_timeout, shard_timeout), - retryable_errors=retryable_errors, - ) - - routine_list = [ - partial(read_rows_with_semaphore, query) for query in sharded_query - ] - batch_result = CrossSync._Sync_Impl.gather_partials( - routine_list, return_exceptions=True, sync_executor=self.client._executor - ) - error_dict = {} - shard_idx = 0 - results_list = [] - for result in batch_result: - if isinstance(result, Exception): - error_dict[shard_idx] = result - elif isinstance(result, BaseException): - raise result - else: - results_list.extend(result) - shard_idx += 1 - if error_dict: - raise ShardedReadRowsExceptionGroup( - [ - FailedQueryShardError(idx, sharded_query[idx], e) - for idx, e in error_dict.items() - ], - results_list, - len(sharded_query), - ) - return results_list - - def row_exists( - self, - row_key: str | bytes, - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - ) -> bool: - """Return a boolean indicating whether the specified row exists in the table. - uses the filters: chain(limit cells per row = 1, strip value) - - Args: - row_key: the key of the row to check - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_read_rows_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_read_rows_attempt_timeout. - If None, defaults to operation_timeout. - retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_read_rows_retryable_errors. 
- Returns: - bool: a bool indicating whether the row exists - Raises: - google.api_core.exceptions.DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions - from any retries that failed - google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error - """ - if row_key is None: - raise ValueError("row_key must be string or bytes") - strip_filter = StripValueTransformerFilter(flag=True) - limit_filter = CellsRowLimitFilter(1) - chain_filter = RowFilterChain(filters=[limit_filter, strip_filter]) - query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter) - results = self.read_rows( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, - ) - return len(results) > 0 - - def sample_row_keys( - self, - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - ) -> RowKeySamples: - """Return a set of RowKeySamples that delimit contiguous sections of the table of - approximately equal size - - RowKeySamples output can be used with ReadRowsQuery.shard() to create a sharded query that - can be parallelized across multiple backend nodes read_rows and read_rows_stream - requests will call sample_row_keys internally for this purpose when sharding is enabled - - RowKeySamples is simply a type alias for list[tuple[bytes, int]]; a list of - row_keys, along with offset positions in the table - - Args: - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget.i - Defaults to the Table's default_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_attempt_timeout. - If None, defaults to operation_timeout. - retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_retryable_errors. 
- Returns: - RowKeySamples: a set of RowKeySamples the delimit contiguous sections of the table - Raises: - google.api_core.exceptions.DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions - from any retries that failed - google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error - """ - operation_timeout, attempt_timeout = _helpers._get_timeouts( - operation_timeout, attempt_timeout, self - ) - attempt_timeout_gen = _helpers._attempt_timeout_generator( - attempt_timeout, operation_timeout - ) - retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - predicate = retries.if_exception_type(*retryable_excs) - sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) - - def execute_rpc(): - results = self.client._gapic_client.sample_row_keys( - table_name=self.table_name, - app_profile_id=self.app_profile_id, - timeout=next(attempt_timeout_gen), - metadata=metadata, - retry=None, - ) - return [(s.row_key, s.offset_bytes) for s in results] - - return CrossSync._Sync_Impl.retry_target( - execute_rpc, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_helpers._retry_exception_factory, - ) - - def mutations_batcher( - self, - *, - flush_interval: float | None = 5, - flush_limit_mutation_count: int | None = 1000, - flush_limit_bytes: int = 20 * _MB_SIZE, - flow_control_max_mutation_count: int = 100000, - flow_control_max_bytes: int = 100 * _MB_SIZE, - batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - batch_retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - ) -> MutationsBatcher: - """Returns a new mutations batcher instance. - - Can be used to iteratively add mutations that are flushed as a group, - to avoid excess network calls - - Args: - flush_interval: Automatically flush every flush_interval seconds. If None, - a table default will be used - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count - mutations are added across all entries. If None, this limit is ignored. - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. - flow_control_max_mutation_count: Maximum number of inflight mutations. - flow_control_max_bytes: Maximum number of inflight bytes. - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. - Defaults to the Table's default_mutate_rows_operation_timeout - batch_attempt_timeout: timeout for each individual request, in seconds. - Defaults to the Table's default_mutate_rows_attempt_timeout. - If None, defaults to batch_operation_timeout. - batch_retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_mutate_rows_retryable_errors. 
- Returns: - MutationsBatcherAsync: a MutationsBatcher context manager that can batch requests - """ - return MutationsBatcher( - self, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_mutation_count, - flush_limit_bytes=flush_limit_bytes, - flow_control_max_mutation_count=flow_control_max_mutation_count, - flow_control_max_bytes=flow_control_max_bytes, - batch_operation_timeout=batch_operation_timeout, - batch_attempt_timeout=batch_attempt_timeout, - batch_retryable_errors=batch_retryable_errors, - ) - - def mutate_row( - self, - row_key: str | bytes, - mutations: list[Mutation] | Mutation, - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - ): - """Mutates a row atomically. - - Cells already present in the row are left unchanged unless explicitly changed - by ``mutation``. - - Idempotent operations (i.e, all mutations have an explicit timestamp) will be - retried on server failure. Non-idempotent operations will not. - - Args: - row_key: the row to apply mutations to - mutations: the set of mutations to apply to the row - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_attempt_timeout. - If None, defaults to operation_timeout. - retryable_errors: a list of errors that will be retried if encountered. - Only idempotent mutations will be retried. Defaults to the Table's - default_retryable_errors. - Raises: - google.api_core.exceptions.DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing all - GoogleAPIError exceptions from any retries that failed - google.api_core.exceptions.GoogleAPIError: raised on non-idempotent operations that cannot be - safely retried. 
- ValueError: if invalid arguments are provided""" - operation_timeout, attempt_timeout = _helpers._get_timeouts( - operation_timeout, attempt_timeout, self - ) - if not mutations: - raise ValueError("No mutations provided") - mutations_list = mutations if isinstance(mutations, list) else [mutations] - if all((mutation.is_idempotent() for mutation in mutations_list)): - predicate = retries.if_exception_type( - *_helpers._get_retryable_errors(retryable_errors, self) - ) - else: - predicate = retries.if_exception_type() - sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - target = partial( - self.client._gapic_client.mutate_row, - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - mutations=[mutation._to_pb() for mutation in mutations_list], - table_name=self.table_name, - app_profile_id=self.app_profile_id, - timeout=attempt_timeout, - metadata=_helpers._make_metadata(self.table_name, self.app_profile_id), - retry=None, - ) - return CrossSync._Sync_Impl.retry_target( - target, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_helpers._retry_exception_factory, - ) - - def bulk_mutate_rows( - self, - mutation_entries: list[RowMutationEntry], - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - ): - """Applies mutations for multiple rows in a single batched request. - - Each individual RowMutationEntry is applied atomically, but separate entries - may be applied in arbitrary order (even for entries targeting the same row) - In total, the row_mutations can contain at most 100000 individual mutations - across all entries - - Idempotent entries (i.e., entries with mutations with explicit timestamps) - will be retried on failure. Non-idempotent entries will not, and will be reported in a - raised exception group - - Args: - mutation_entries: the batches of mutations to apply - Each entry will be applied atomically, but entries will be applied - in arbitrary order - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to the Table's default_mutate_rows_operation_timeout - attempt_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the Table's default_mutate_rows_attempt_timeout. - If None, defaults to operation_timeout. - retryable_errors: a list of errors that will be retried if encountered.
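As the mutate_row logic above shows, the retry predicate is only applied when every mutation is idempotent, i.e. carries an explicit timestamp. A hedged sketch of the difference; the SetCell keyword names and the SERVER_SIDE_TIMESTAMP sentinel are assumed from the library's mutations module rather than shown in this diff:

from google.cloud.bigtable.data.mutations import SetCell, SERVER_SIDE_TIMESTAMP

# explicit timestamp (microseconds) -> idempotent, eligible for the retry predicate
safe = SetCell("stats", b"count", b"\x01", timestamp_micros=1_700_000_000_000_000)

# server-assigned timestamp -> non-idempotent, fails fast on transient errors
unsafe = SetCell("stats", b"count", b"\x01", timestamp_micros=SERVER_SIDE_TIMESTAMP)

# table.mutate_row(b"row-1", safe)    # retried within operation_timeout
# table.mutate_row(b"row-1", unsafe)  # not retried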
- Defaults to the Table's default_mutate_rows_retryable_errors - Raises: - MutationsExceptionGroup: if one or more mutations fails - Contains details about any failed entries in .exceptions - ValueError: if invalid arguments are provided""" - operation_timeout, attempt_timeout = _helpers._get_timeouts( - operation_timeout, attempt_timeout, self - ) - retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) - operation = _MutateRowsOperation( - self.client._gapic_client, - self, - mutation_entries, - operation_timeout, - attempt_timeout, - retryable_exceptions=retryable_excs, - ) - operation.start() - - def check_and_mutate_row( - self, - row_key: str | bytes, - predicate: RowFilter | None, - *, - true_case_mutations: Mutation | list[Mutation] | None = None, - false_case_mutations: Mutation | list[Mutation] | None = None, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - ) -> bool: - """Mutates a row atomically based on the output of a predicate filter - - Non-idempotent operation: will not be retried - - Args: - row_key: the key of the row to mutate - predicate: the filter to be applied to the contents of the specified row. - Depending on whether or not any results are yielded, - either true_case_mutations or false_case_mutations will be executed. - If None, checks that the row contains any values at all. - true_case_mutations: - Changes to be atomically applied to the specified row if - predicate yields at least one cell when - applied to row_key. Entries are applied in order, - meaning that earlier mutations can be masked by later - ones. Must contain at least one entry if - false_case_mutations is empty, and at most 100000. - false_case_mutations: - Changes to be atomically applied to the specified row if - predicate_filter does not yield any cells when - applied to row_key. Entries are applied in order, - meaning that earlier mutations can be masked by later - ones. Must contain at least one entry if - `true_case_mutations` is empty, and at most 100000. - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will not be retried. 
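bulk_mutate_rows surfaces partial failures as a MutationsExceptionGroup rather than returning per-entry results, so callers typically inspect .exceptions to find which entries failed. A minimal sketch, with attribute names taken from the docstrings and exception handling shown in this diff:

from google.cloud.bigtable.data.exceptions import (
    MutationsExceptionGroup,
    FailedMutationEntryError,
)

def apply_entries(table, entries):
    try:
        table.bulk_mutate_rows(entries)
    except MutationsExceptionGroup as group:
        failed = [e for e in group.exceptions if isinstance(e, FailedMutationEntryError)]
        # each FailedMutationEntryError carries the original entry (and its index, when known)
        return [f.entry for f in failed]
    return []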
Defaults to the Table's default_operation_timeout - Returns: - bool indicating whether the predicate was true or false - Raises: - google.api_core.exceptions.GoogleAPIError: exceptions from grpc call""" - operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) - if true_case_mutations is not None and ( - not isinstance(true_case_mutations, list) - ): - true_case_mutations = [true_case_mutations] - true_case_list = [m._to_pb() for m in true_case_mutations or []] - if false_case_mutations is not None and ( - not isinstance(false_case_mutations, list) - ): - false_case_mutations = [false_case_mutations] - false_case_list = [m._to_pb() for m in false_case_mutations or []] - metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) - result = self.client._gapic_client.check_and_mutate_row( - true_mutations=true_case_list, - false_mutations=false_case_list, - predicate_filter=predicate._to_pb() if predicate is not None else None, - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, - ) - return result.predicate_matched - - def read_modify_write_row( - self, - row_key: str | bytes, - rules: ReadModifyWriteRule | list[ReadModifyWriteRule], - *, - operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - ) -> Row: - """Reads and modifies a row atomically according to input ReadModifyWriteRules, - and returns the contents of all modified cells - - The new value for the timestamp is the greater of the existing timestamp or - the current server time. - - Non-idempotent operation: will not be retried - - Args: - row_key: the key of the row to apply read/modify/write rules to - rules: A rule or set of rules to apply to the row. - Rules are applied in order, meaning that earlier rules will affect the - results of later ones. - operation_timeout: the time budget for the entire operation, in seconds. - Failed requests will not be retried. - Defaults to the Table's default_operation_timeout. 
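check_and_mutate_row runs one of two mutation lists depending on whether the predicate filter matches any cells in the row. A hedged usage sketch; the filter and mutation class names are assumed from the library's public data API and are not part of this diff:

from google.cloud.bigtable.data.row_filters import FamilyNameRegexFilter
from google.cloud.bigtable.data.mutations import SetCell, DeleteAllFromRow

def touch_or_reset(table, row_key: bytes) -> bool:
    # if the row has any cell in family "profile", stamp a marker cell;
    # otherwise clear the row entirely
    return table.check_and_mutate_row(
        row_key,
        FamilyNameRegexFilter("profile"),
        true_case_mutations=[SetCell("profile", b"seen", b"1")],
        false_case_mutations=[DeleteAllFromRow()],
    )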
- Returns: - Row: a Row containing cell data that was modified as part of the operation - Raises: - google.api_core.exceptions.GoogleAPIError: exceptions from grpc call - ValueError: if invalid arguments are provided""" - operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) - if operation_timeout <= 0: - raise ValueError("operation_timeout must be greater than 0") - if rules is not None and (not isinstance(rules, list)): - rules = [rules] - if not rules: - raise ValueError("rules must contain at least one item") - metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) - result = self.client._gapic_client.read_modify_write_row( - rules=[rule._to_pb() for rule in rules], - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, - ) - return Row._from_pb(result.row) - - def close(self): - """Called to close the Table instance and release any resources held by it.""" - if self._register_instance_future: - self._register_instance_future.cancel() - self.client._remove_instance_registration(self.instance_id, self) - - def __enter__(self): - """Implement async context manager protocol - - Ensure registration task has time to run, so that - grpc channels will be warmed for the specified instance""" - if self._register_instance_future: - self._register_instance_future - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Implement async context manager protocol - - Unregister this instance with the client, so that - grpc channels will no longer be warmed""" - self.close() diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py deleted file mode 100644 index 006982c1f..000000000 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ /dev/null @@ -1,451 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is automatically generated by CrossSync. Do not edit manually. 
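read_modify_write_row applies rules in order and returns only the modified cells, which makes it a natural fit for counters. A minimal sketch against the sync surface above, using the rule classes and cell access patterns that appear in the system tests later in this patch:

from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule, AppendValueRule

def bump_counter(table, row_key: bytes) -> int:
    row = table.read_modify_write_row(
        row_key,
        [
            IncrementRule("stats", b"visits", 1),      # add 1 to the big-endian counter cell
            AppendValueRule("stats", b"log", b"+1;"),  # then append to an audit cell
        ],
    )
    # the returned Row contains only the cells touched by the rules
    return int(row[0])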
- -# mypy: disable-error-code="unreachable" - -from __future__ import annotations -from typing import Sequence, TYPE_CHECKING -import atexit -import warnings -from collections import deque -import concurrent.futures -from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup -from google.cloud.bigtable.data.exceptions import FailedMutationEntryError -from google.cloud.bigtable.data._helpers import _get_retryable_errors -from google.cloud.bigtable.data._helpers import _get_timeouts -from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data._helpers import _MB_SIZE -from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT -from google.cloud.bigtable.data.mutations import Mutation -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -if CrossSync._Sync_Impl.is_async: - pass -else: - from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation -if TYPE_CHECKING: - from google.cloud.bigtable.data.mutations import RowMutationEntry - - if CrossSync._Sync_Impl.is_async: - pass - else: - from google.cloud.bigtable.data._sync.client import Table - - -class _FlowControl: - """ - Manages flow control for batched mutations. Mutations are registered against - the FlowControl object before being sent, which will block if size or count - limits have reached capacity. As mutations completed, they are removed from - the FlowControl object, which will notify any blocked requests that there - is additional capacity. - - Flow limits are not hard limits. If a single mutation exceeds the configured - limits, it will be allowed as a single batch when the capacity is available. - - Args: - max_mutation_count: maximum number of mutations to send in a single rpc. - This corresponds to individual mutations in a single RowMutationEntry. - max_mutation_bytes: maximum number of bytes to send in a single rpc. - Raises: - ValueError: if max_mutation_count or max_mutation_bytes is less than 0 - """ - - def __init__(self, max_mutation_count: int, max_mutation_bytes: int): - self._max_mutation_count = max_mutation_count - self._max_mutation_bytes = max_mutation_bytes - if self._max_mutation_count < 1: - raise ValueError("max_mutation_count must be greater than 0") - if self._max_mutation_bytes < 1: - raise ValueError("max_mutation_bytes must be greater than 0") - self._capacity_condition = CrossSync._Sync_Impl.Condition() - self._in_flight_mutation_count = 0 - self._in_flight_mutation_bytes = 0 - - def _has_capacity(self, additional_count: int, additional_size: int) -> bool: - """Checks if there is capacity to send a new entry with the given size and count - - FlowControl limits are not hard limits. If a single mutation exceeds - the configured flow limits, it will be sent in a single batch when - previous batches have completed. - - Args: - additional_count: number of mutations in the pending entry - additional_size: size of the pending entry - Returns: - bool: True if there is capacity to send the pending entry, False otherwise - """ - acceptable_size = max(self._max_mutation_bytes, additional_size) - acceptable_count = max(self._max_mutation_count, additional_count) - new_size = self._in_flight_mutation_bytes + additional_size - new_count = self._in_flight_mutation_count + additional_count - return new_size <= acceptable_size and new_count <= acceptable_count - - def remove_from_flow( - self, mutations: RowMutationEntry | list[RowMutationEntry] - ) -> None: - """Removes mutations from flow control. 
This method should be called once - for each mutation that was sent to add_to_flow, after the corresponding - operation is complete. - - Args: - mutations: mutation or list of mutations to remove from flow control""" - if not isinstance(mutations, list): - mutations = [mutations] - total_count = sum((len(entry.mutations) for entry in mutations)) - total_size = sum((entry.size() for entry in mutations)) - self._in_flight_mutation_count -= total_count - self._in_flight_mutation_bytes -= total_size - with self._capacity_condition: - self._capacity_condition.notify_all() - - def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): - """Generator function that registers mutations with flow control. As mutations - are accepted into the flow control, they are yielded back to the caller, - to be sent in a batch. If the flow control is at capacity, the generator - will block until there is capacity available. - - Args: - mutations: list mutations to break up into batches - Yields: - list[RowMutationEntry]: - list of mutations that have reserved space in the flow control. - Each batch contains at least one mutation.""" - if not isinstance(mutations, list): - mutations = [mutations] - start_idx = 0 - end_idx = 0 - while end_idx < len(mutations): - start_idx = end_idx - batch_mutation_count = 0 - with self._capacity_condition: - while end_idx < len(mutations): - next_entry = mutations[end_idx] - next_size = next_entry.size() - next_count = len(next_entry.mutations) - if ( - self._has_capacity(next_count, next_size) - and batch_mutation_count + next_count - <= _MUTATE_ROWS_REQUEST_MUTATION_LIMIT - ): - end_idx += 1 - batch_mutation_count += next_count - self._in_flight_mutation_bytes += next_size - self._in_flight_mutation_count += next_count - elif start_idx != end_idx: - break - else: - self._capacity_condition.wait_for( - lambda: self._has_capacity(next_count, next_size) - ) - yield mutations[start_idx:end_idx] - - -class MutationsBatcher: - """ - Allows users to send batches using context manager API: - - Runs mutate_row, mutate_rows, and check_and_mutate_row internally, combining - to use as few network requests as required - - Will automatically flush the batcher: - - every flush_interval seconds - - after queue size reaches flush_limit_mutation_count - - after queue reaches flush_limit_bytes - - when batcher is closed or destroyed - - Args: - table: Table to preform rpc calls - flush_interval: Automatically flush every flush_interval seconds. - If None, no time-based flushing is performed. - flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count - mutations are added across all entries. If None, this limit is ignored. - flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. - flow_control_max_mutation_count: Maximum number of inflight mutations. - flow_control_max_bytes: Maximum number of inflight bytes. - batch_operation_timeout: timeout for each mutate_rows operation, in seconds. - If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_operation_timeout. - batch_attempt_timeout: timeout for each individual request, in seconds. - If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_attempt_timeout. - If None, defaults to batch_operation_timeout. - batch_retryable_errors: a list of errors that will be retried if encountered. - Defaults to the Table's default_mutate_rows_retryable_errors. 
- """ - - def __init__( - self, - table: Table, - *, - flush_interval: float | None = 5, - flush_limit_mutation_count: int | None = 1000, - flush_limit_bytes: int = 20 * _MB_SIZE, - flow_control_max_mutation_count: int = 100000, - flow_control_max_bytes: int = 100 * _MB_SIZE, - batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - batch_retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - ): - self._operation_timeout, self._attempt_timeout = _get_timeouts( - batch_operation_timeout, batch_attempt_timeout, table - ) - self._retryable_errors: list[type[Exception]] = _get_retryable_errors( - batch_retryable_errors, table - ) - self._closed = CrossSync._Sync_Impl.Event() - self._table = table - self._staged_entries: list[RowMutationEntry] = [] - self._staged_count, self._staged_bytes = (0, 0) - self._flow_control = _FlowControl( - flow_control_max_mutation_count, flow_control_max_bytes - ) - self._flush_limit_bytes = flush_limit_bytes - self._flush_limit_count = ( - flush_limit_mutation_count - if flush_limit_mutation_count is not None - else float("inf") - ) - self._sync_executor = ( - concurrent.futures.ThreadPoolExecutor(max_workers=8) - if not CrossSync._Sync_Impl.is_async - else None - ) - self._flush_timer = CrossSync._Sync_Impl.create_task( - self._timer_routine, flush_interval, sync_executor=self._sync_executor - ) - self._flush_jobs: set[CrossSync._Sync_Impl.Future[None]] = set() - self._entries_processed_since_last_raise: int = 0 - self._exceptions_since_last_raise: int = 0 - self._exception_list_limit: int = 10 - self._oldest_exceptions: list[Exception] = [] - self._newest_exceptions: deque[Exception] = deque( - maxlen=self._exception_list_limit - ) - atexit.register(self._on_exit) - - def _timer_routine(self, interval: float | None) -> None: - """Set up a background task to flush the batcher every interval seconds - - If interval is None, an empty future is returned - - Args: - flush_interval: Automatically flush every flush_interval seconds. - If None, no time-based flushing is performed.""" - if not interval or interval <= 0: - return None - while not self._closed.is_set(): - CrossSync._Sync_Impl.event_wait( - self._closed, timeout=interval, async_break_early=False - ) - if not self._closed.is_set() and self._staged_entries: - self._schedule_flush() - - def append(self, mutation_entry: RowMutationEntry): - """Add a new set of mutations to the internal queue - - Args: - mutation_entry: new entry to add to flush queue - Raises: - RuntimeError: if batcher is closed - ValueError: if an invalid mutation type is added""" - if self._closed.is_set(): - raise RuntimeError("Cannot append to closed MutationsBatcher") - if isinstance(mutation_entry, Mutation): - raise ValueError( - f"invalid mutation type: {type(mutation_entry).__name__}. 
Only RowMutationEntry objects are supported by batcher" - ) - self._staged_entries.append(mutation_entry) - self._staged_count += len(mutation_entry.mutations) - self._staged_bytes += mutation_entry.size() - if ( - self._staged_count >= self._flush_limit_count - or self._staged_bytes >= self._flush_limit_bytes - ): - self._schedule_flush() - CrossSync._Sync_Impl.yield_to_event_loop() - - def _schedule_flush(self) -> CrossSync._Sync_Impl.Future[None] | None: - """Update the flush task to include the latest staged entries - - Returns: - Future[None] | None: - future representing the background task, if started""" - if self._staged_entries: - entries, self._staged_entries = (self._staged_entries, []) - self._staged_count, self._staged_bytes = (0, 0) - new_task = CrossSync._Sync_Impl.create_task( - self._flush_internal, entries, sync_executor=self._sync_executor - ) - if not new_task.done(): - self._flush_jobs.add(new_task) - new_task.add_done_callback(self._flush_jobs.remove) - return new_task - return None - - def _flush_internal(self, new_entries: list[RowMutationEntry]): - """Flushes a set of mutations to the server, and updates internal state - - Args: - new_entries list of RowMutationEntry objects to flush""" - in_process_requests: list[ - CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] - ] = [] - for batch in self._flow_control.add_to_flow(new_entries): - batch_task = CrossSync._Sync_Impl.create_task( - self._execute_mutate_rows, batch, sync_executor=self._sync_executor - ) - in_process_requests.append(batch_task) - found_exceptions = self._wait_for_batch_results(*in_process_requests) - self._entries_processed_since_last_raise += len(new_entries) - self._add_exceptions(found_exceptions) - - def _execute_mutate_rows( - self, batch: list[RowMutationEntry] - ) -> list[FailedMutationEntryError]: - """Helper to execute mutation operation on a batch - - Args: - batch: list of RowMutationEntry objects to send to server - timeout: timeout in seconds. Used as operation_timeout and attempt_timeout. - If not given, will use table defaults - Returns: - list[FailedMutationEntryError]: - list of FailedMutationEntryError objects for mutations that failed. - FailedMutationEntryError objects will not contain index information""" - try: - operation = _MutateRowsOperation( - self._table.client._gapic_client, - self._table, - batch, - operation_timeout=self._operation_timeout, - attempt_timeout=self._attempt_timeout, - retryable_exceptions=self._retryable_errors, - ) - operation.start() - except MutationsExceptionGroup as e: - for subexc in e.exceptions: - subexc.index = None - return list(e.exceptions) - finally: - self._flow_control.remove_from_flow(batch) - return [] - - def _add_exceptions(self, excs: list[Exception]): - """Add new list of exceptions to internal store. To avoid unbounded memory, - the batcher will store the first and last _exception_list_limit exceptions, - and discard any in between. 
- - Args: - excs: list of exceptions to add to the internal store""" - self._exceptions_since_last_raise += len(excs) - if excs and len(self._oldest_exceptions) < self._exception_list_limit: - addition_count = self._exception_list_limit - len(self._oldest_exceptions) - self._oldest_exceptions.extend(excs[:addition_count]) - excs = excs[addition_count:] - if excs: - self._newest_exceptions.extend(excs[-self._exception_list_limit :]) - - def _raise_exceptions(self): - """Raise any unreported exceptions from background flush operations - - Raises: - MutationsExceptionGroup: exception group with all unreported exceptions""" - if self._oldest_exceptions or self._newest_exceptions: - oldest, self._oldest_exceptions = (self._oldest_exceptions, []) - newest = list(self._newest_exceptions) - self._newest_exceptions.clear() - entry_count, self._entries_processed_since_last_raise = ( - self._entries_processed_since_last_raise, - 0, - ) - exc_count, self._exceptions_since_last_raise = ( - self._exceptions_since_last_raise, - 0, - ) - raise MutationsExceptionGroup.from_truncated_lists( - first_list=oldest, - last_list=newest, - total_excs=exc_count, - entry_count=entry_count, - ) - - def __enter__(self): - """Allow use of context manager API""" - return self - - def __exit__(self, exc_type, exc, tb): - """Allow use of context manager API. - - Flushes the batcher and cleans up resources.""" - self.close() - - @property - def closed(self) -> bool: - """Returns: - - True if the batcher is closed, False otherwise""" - return self._closed.is_set() - - def close(self): - """Flush queue and clean up resources""" - self._closed.set() - self._flush_timer.cancel() - self._schedule_flush() - CrossSync._Sync_Impl.wait([*self._flush_jobs, self._flush_timer]) - if self._sync_executor: - with self._sync_executor: - self._sync_executor.shutdown(wait=True) - atexit.unregister(self._on_exit) - self._raise_exceptions() - - def _on_exit(self): - """Called when program is exited. Raises warning if unflushed mutations remain""" - if not self._closed.is_set() and self._staged_entries: - warnings.warn( - f"MutationsBatcher for table {self._table.table_name} was not closed. {len(self._staged_entries)} Unflushed mutations will not be sent to the server." - ) - - @staticmethod - def _wait_for_batch_results( - *tasks: CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] - | CrossSync._Sync_Impl.Future[None], - ) -> list[Exception]: - """Takes in a list of futures representing _execute_mutate_rows tasks, - waits for them to complete, and returns a list of errors encountered. - - Args: - *tasks: futures representing _execute_mutate_rows or _flush_internal tasks - Returns: - list[Exception]: - list of Exceptions encountered by any of the tasks. Errors are expected - to be FailedMutationEntryError, representing a failed mutation operation. - If a task fails with a different exception, it will be included in the - output list. Successful tasks will not be represented in the output list. 
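Putting the batcher pieces above together, typical use goes through the context manager, which flushes on exit and re-raises accumulated failures as a MutationsExceptionGroup from close(). A minimal sketch; parameter names follow the constructor shown earlier, and the entries are assumed to be RowMutationEntry objects:

from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup

def write_entries(table, entries):
    try:
        with table.mutations_batcher(
            flush_interval=1.0,                 # time-based flush
            flush_limit_mutation_count=500,     # count-based flush
            flush_limit_bytes=5 * 1024 * 1024,  # size-based flush
        ) as batcher:
            for entry in entries:
                batcher.append(entry)
        # __exit__ -> close() -> final flush, then _raise_exceptions()
    except MutationsExceptionGroup as group:
        return list(group.exceptions)
    return []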
- """ - if not tasks: - return [] - exceptions: list[Exception] = [] - for task in tasks: - if CrossSync._Sync_Impl.is_async: - task - try: - exc_list = task.result() - if exc_list: - for exc in exc_list: - exc.index = None - exceptions.extend(exc_list) - except Exception as e: - exceptions.append(e) - return exceptions diff --git a/google/cloud/bigtable/data/_sync/sync_gen.yaml b/google/cloud/bigtable/data/_sync/sync_gen.yaml deleted file mode 100644 index aa5282e69..000000000 --- a/google/cloud/bigtable/data/_sync/sync_gen.yaml +++ /dev/null @@ -1,62 +0,0 @@ -asyncio_replacements: # Replace asyncio functionaility - sleep: time.sleep - Queue: queue.Queue - Condition: threading.Condition - Future: concurrent.futures.Future - Task: concurrent.futures.Future - Event: threading.Event - -text_replacements: # Find and replace specific text patterns - __anext__: __next__ - __aiter__: __iter__ - __aenter__: __enter__ - __aexit__: __exit__ - aclose: close - AsyncIterable: Iterable - AsyncIterator: Iterator - StopAsyncIteration: StopIteration - Awaitable: None - BigtableAsyncClient: BigtableClient - retry_target_async: retry_target - retry_target_stream_async: retry_target_stream - -added_imports: - - "from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient" - - "from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport, PooledChannel" - - "from typing import Generator, Iterable, Iterator" - - "from grpc import Channel" - - "import google.cloud.bigtable.data.exceptions as bt_exceptions" - - "import threading" - - "import concurrent.futures" - -classes: # Specify transformations for individual classes - - path: google.cloud.bigtable.data._async._read_rows._ReadRowsOperationAsync - autogen_sync_name: _ReadRowsOperation_SyncGen - concrete_path: google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation - - path: google.cloud.bigtable.data._async._mutate_rows._MutateRowsOperationAsync - autogen_sync_name: _MutateRowsOperation_SyncGen - concrete_path: google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation - - path: google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync - autogen_sync_name: MutationsBatcher_SyncGen - concrete_path: google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher - error_methods: ["_create_bg_task", "close", "_wait_for_batch_results", "_timer_routine"] - - path: google.cloud.bigtable.data._async.mutations_batcher._FlowControlAsync - autogen_sync_name: _FlowControl_SyncGen - concrete_path: google.cloud.bigtable.data._sync.mutations_batcher._FlowControl - - path: google.cloud.bigtable.data._async.client.BigtableDataClientAsync - autogen_sync_name: BigtableDataClient_SyncGen - concrete_path: google.cloud.bigtable.data._sync.client.BigtableDataClient - drop_methods: ["close"] - error_methods: ["_start_background_channel_refresh", "_client_version", "_execute_ping_and_warms"] - asyncio_replacements: - sleep: self._is_closed.wait - text_replacements: - PooledBigtableGrpcAsyncIOTransport: PooledBigtableGrpcTransport - AsyncChannel: Channel - AsyncPooledChannel: PooledChannel - - path: google.cloud.bigtable.data._async.client.TableAsync - autogen_sync_name: Table_SyncGen - concrete_path: google.cloud.bigtable.data._sync.client.Table - error_methods: ["_register_with_client", "_shard_batch_helper", "__aenter__"] - -save_path: "google/cloud/bigtable/data/_sync/_autogen.py" diff --git a/google/cloud/bigtable/data/_sync/system_tests.yaml 
b/google/cloud/bigtable/data/_sync/system_tests.yaml deleted file mode 100644 index 43c78aa9e..000000000 --- a/google/cloud/bigtable/data/_sync/system_tests.yaml +++ /dev/null @@ -1,25 +0,0 @@ -asyncio_replacements: # Replace asyncio functionaility - sleep: time.sleep - -added_imports: - - "from .test_system_async import TEST_FAMILY, TEST_FAMILY_2" - - "from google.cloud.bigtable.data import BigtableDataClient" - - "import time" - -text_replacements: - pytest_asyncio: pytest - AsyncRetry: Retry - BigtableDataClientAsync: BigtableDataClient - TempRowBuilderAsync: TempRowBuilder - StopAsyncIteration: StopIteration - __anext__: __next__ - aclose: close - -classes: - - path: tests.system.data.test_system_async.TempRowBuilderAsync - autogen_sync_name: TempRowBuilder - - path: tests.system.data.test_system_async.TestSystemAsync - autogen_sync_name: TestSystemSync - drop_methods: ["event_loop"] - -save_path: "tests/system/data/test_system.py" diff --git a/google/cloud/bigtable/data/_sync/transformers.py b/google/cloud/bigtable/data/_sync/transformers.py deleted file mode 100644 index 60498763d..000000000 --- a/google/cloud/bigtable/data/_sync/transformers.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from __future__ import annotations - -import ast - -from dataclasses import dataclass, field -from .cross_sync import CrossSync - - -class SymbolReplacer(ast.NodeTransformer): - """ - Replaces all instances of a symbol in an AST with a replacement - - Works for function signatures, method calls, docstrings, and type annotations - """ - def __init__(self, replacements: dict[str, str]): - self.replacements = replacements - - def visit_Name(self, node): - if node.id in self.replacements: - node.id = self.replacements[node.id] - return node - - def visit_Attribute(self, node): - return ast.copy_location( - ast.Attribute( - self.visit(node.value), - self.replacements.get(node.attr, node.attr), - node.ctx, - ), - node, - ) - - def visit_AsyncFunctionDef(self, node): - """ - Replace async function docstrings - """ - # use same logic as FunctionDef - return self.visit_FunctionDef(node) - - def visit_FunctionDef(self, node): - """ - Replace function docstrings - """ - docstring = ast.get_docstring(node) - if docstring and isinstance(node.body[0], ast.Expr) and isinstance( - node.body[0].value, ast.Str - ): - for key_word, replacement in self.replacements.items(): - docstring = docstring.replace(f" {key_word} ", f" {replacement} ") - node.body[0].value.s = docstring - return self.generic_visit(node) - - def visit_Str(self, node): - """Replace string type annotations""" - node.s = self.replacements.get(node.s, node.s) - return node - - -class AsyncToSync(ast.NodeTransformer): - """ - Replaces or strips all async keywords from a given AST - """ - def visit_Await(self, node): - """ - Strips await keyword - """ - return self.visit(node.value) - - def visit_AsyncFor(self, node): - """ - Replaces `async for` with `for` - """ - return ast.copy_location( - ast.For( - self.visit(node.target), - self.visit(node.iter), - [self.visit(stmt) for stmt in node.body], - [self.visit(stmt) for stmt in node.orelse], - ), - node, - ) - - def visit_AsyncWith(self, node): - """ - Replaces `async with` with `with` - """ - return ast.copy_location( - ast.With( - [self.visit(item) for item in node.items], - [self.visit(stmt) for stmt in node.body], - ), - node, - ) - - def visit_AsyncFunctionDef(self, node): - """ - Replaces `async def` with `def` - """ - return ast.copy_location( - ast.FunctionDef( - node.name, - self.visit(node.args), - [self.visit(stmt) for stmt in node.body], - [self.visit(decorator) for decorator in node.decorator_list], - node.returns and self.visit(node.returns), - ), - node, - ) - - def visit_ListComp(self, node): - """ - Replaces `async for` with `for` in list comprehensions - """ - for generator in node.generators: - generator.is_async = False - return self.generic_visit(node) - - -class CrossSyncMethodDecoratorHandler(ast.NodeTransformer): - """ - Visits each method in a class, and handles any CrossSync decorators found - """ - - def visit_FunctionDef(self, node): - return self.visit_AsyncFunctionDef(node) - - def visit_AsyncFunctionDef(self, node): - try: - if hasattr(node, "decorator_list"): - found_list, node.decorator_list = node.decorator_list, [] - for decorator in found_list: - if decorator == CrossSync.convert: - # convert async to sync - kwargs = CrossSync.convert.parse_ast_keywords(decorator) - node = AsyncToSync().visit(node) - # replace method name if specified - if kwargs["sync_name"] is not None: - node.name = kwargs["sync_name"] - # replace symbols if specified - if kwargs["replace_symbols"]: - node = SymbolReplacer(kwargs["replace_symbols"]).visit(node) - elif decorator == 
CrossSync.drop_method: - # drop method entirely from class - return None - elif decorator == CrossSync.pytest: - # also convert pytest methods to sync - node = AsyncToSync().visit(node) - elif decorator == CrossSync.pytest_fixture: - # add pytest.fixture decorator - decorator.func.value = ast.Name(id="pytest", ctx=ast.Load()) - decorator.func.attr = "fixture" - node.decorator_list.append(decorator) - else: - # keep unknown decorators - node.decorator_list.append(decorator) - return node - except ValueError as e: - raise ValueError(f"node {node.name} failed") from e - - -@dataclass -class CrossSyncFileArtifact: - """ - Used to track an output file location. Collects a number of converted classes, and then - writes them to disk - """ - - file_path: str - imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = field( - default_factory=list - ) - converted_classes: list[ast.ClassDef] = field(default_factory=list) - contained_classes: set[str] = field(default_factory=set) - mypy_ignore: list[str] = field(default_factory=list) - - def __hash__(self): - return hash(self.file_path) - - def __repr__(self): - return f"CrossSyncFileArtifact({self.file_path}, classes={[c.name for c in self.converted_classes]})" - - def render(self, with_black=True, save_to_disk=False) -> str: - full_str = ( - "# Copyright 2024 Google LLC\n" - "#\n" - '# Licensed under the Apache License, Version 2.0 (the "License");\n' - "# you may not use this file except in compliance with the License.\n" - "# You may obtain a copy of the License at\n" - "#\n" - "# http://www.apache.org/licenses/LICENSE-2.0\n" - "#\n" - "# Unless required by applicable law or agreed to in writing, software\n" - '# distributed under the License is distributed on an "AS IS" BASIS,\n' - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" - "# See the License for the specific language governing permissions and\n" - "# limitations under the License.\n" - "#\n" - "# This file is automatically generated by CrossSync. Do not edit manually.\n" - ) - if self.mypy_ignore: - full_str += ( - f'\n# mypy: disable-error-code="{",".join(self.mypy_ignore)}"\n\n' - ) - full_str += "\n".join([ast.unparse(node) for node in self.imports]) # type: ignore - full_str += "\n\n" - full_str += "\n".join([ast.unparse(node) for node in self.converted_classes]) # type: ignore - if with_black: - import black # type: ignore - import autoflake # type: ignore - - full_str = black.format_str( - autoflake.fix_code(full_str, remove_all_unused_imports=True), - mode=black.FileMode(), - ) - if save_to_disk: - with open(self.file_path, "w") as f: - f.write(full_str) - return full_str - - -class CrossSyncClassDecoratorHandler(ast.NodeTransformer): - """ - Visits each class in the file, and if it has a CrossSync decorator, it will be transformed. - - Uses CrossSyncMethodDecoratorHandler to visit and (potentially) convert each method in the class - """ - def __init__(self, file_path): - self.in_path = file_path - self._artifact_dict: dict[str, CrossSyncFileArtifact] = {} - self.imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = [] - self.cross_sync_symbol_transformer = SymbolReplacer( - {"CrossSync": "CrossSync._Sync_Impl"} - ) - self.cross_sync_method_handler = CrossSyncMethodDecoratorHandler() - - def convert_file( - self, artifacts: set[CrossSyncFileArtifact] | None = None - ) -> set[CrossSyncFileArtifact]: - """ - Called to run a file through the transformer. 
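The SymbolReplacer and AsyncToSync transformers above are plain ast.NodeTransformers, so their effect is easy to see on a toy input. The following is a self-contained, standard-library-only sketch of the same two ideas (symbol renaming and await stripping); it is not wired into the repo's CrossSync machinery:

import ast

class Demo(ast.NodeTransformer):
    """Toy combination of SymbolReplacer renaming and AsyncToSync.visit_Await."""

    replacements = {"AsyncIterator": "Iterator", "__anext__": "__next__"}

    def visit_Name(self, node):
        # rename bare symbols (e.g. type annotations)
        node.id = self.replacements.get(node.id, node.id)
        return node

    def visit_Attribute(self, node):
        # rename attribute accesses like it.__anext__
        self.generic_visit(node)
        node.attr = self.replacements.get(node.attr, node.attr)
        return node

    def visit_Await(self, node):
        # strip the await keyword, keeping the awaited expression
        return self.visit(node.value)

src = "async def consume(it: AsyncIterator):\n    return await it.__anext__()\n"
tree = ast.fix_missing_locations(Demo().visit(ast.parse(src)))
print(ast.unparse(tree))
# -> the annotation becomes Iterator, the call becomes it.__next__(), and the await is gone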
If any classes are marked with a CrossSync decorator, - they will be transformed and added to an artifact for the output file - """ - tree = ast.parse(open(self.in_path).read()) - self._artifact_dict = {f.file_path: f for f in artifacts or []} - self.imports = self._get_imports(tree) - self.visit(tree) - found = set(self._artifact_dict.values()) - if artifacts is not None: - artifacts.update(found) - return found - - def visit_ClassDef(self, node): - """ - Called for each class in file. If class has a CrossSync decorator, it will be transformed - according to the decorator arguments - """ - try: - for decorator in node.decorator_list: - if decorator == CrossSync.export_sync: - kwargs = CrossSync.export_sync.parse_ast_keywords(decorator) - # find the path to write the sync class to - sync_path = kwargs["path"] - out_file = "/".join(sync_path.rsplit(".")[:-1]) + ".py" - sync_cls_name = sync_path.rsplit(".", 1)[-1] - # find the artifact file for the save location - output_artifact = self._artifact_dict.get( - out_file, CrossSyncFileArtifact(out_file) - ) - # write converted class details if not already present - if sync_cls_name not in output_artifact.contained_classes: - converted = self._transform_class(node, sync_cls_name, **kwargs) - output_artifact.converted_classes.append(converted) - # handle file-level mypy ignores - mypy_ignores = [ - s - for s in kwargs["mypy_ignore"] - if s not in output_artifact.mypy_ignore - ] - output_artifact.mypy_ignore.extend(mypy_ignores) - # handle file-level imports - if not output_artifact.imports and kwargs["include_file_imports"]: - output_artifact.imports = self.imports - self._artifact_dict[out_file] = output_artifact - return node - except ValueError as e: - raise ValueError(f"failed for class: {node.name}") from e - - def _transform_class( - self, - cls_ast: ast.ClassDef, - new_name: str, - replace_symbols: dict[str, str] | None = None, - **kwargs, - ) -> ast.ClassDef: - """ - Transform async class into sync one, by running through a series of transformers - """ - # update name - cls_ast.name = new_name - # strip CrossSync decorators - if hasattr(cls_ast, "decorator_list"): - cls_ast.decorator_list = [ - d for d in cls_ast.decorator_list if "CrossSync" not in ast.dump(d) - ] - # convert class contents - cls_ast = self.cross_sync_symbol_transformer.visit(cls_ast) - if replace_symbols: - cls_ast = SymbolReplacer(replace_symbols).visit(cls_ast) - cls_ast = self.cross_sync_method_handler.visit(cls_ast) - return cls_ast - - def _get_imports( - self, tree: ast.Module - ) -> list[ast.Import | ast.ImportFrom | ast.Try | ast.If]: - """ - Grab the imports from the top of the file - """ - imports = [] - for node in tree.body: - if isinstance(node, (ast.Import, ast.ImportFrom, ast.Try, ast.If)): - imports.append(self.cross_sync_symbol_transformer.visit(node)) - return imports - - diff --git a/google/cloud/bigtable/data/_sync/unit_tests.yaml b/google/cloud/bigtable/data/_sync/unit_tests.yaml deleted file mode 100644 index 6a4cb2159..000000000 --- a/google/cloud/bigtable/data/_sync/unit_tests.yaml +++ /dev/null @@ -1,116 +0,0 @@ -asyncio_replacements: # Replace entire modules - sleep: time.sleep - Queue: queue.Queue - Condition: threading.Condition - Future: concurrent.futures.Future - create_task: threading.Thread - -added_imports: - - "import google.api_core.exceptions as core_exceptions" - - "import threading" - - "import concurrent.futures" - - "from google.cloud.bigtable.data import Table" - -text_replacements: # Find and replace specific text patterns - 
__anext__: __next__ - __aiter__: __iter__ - __aenter__: __enter__ - __aexit__: __exit__ - aclose: close - AsyncIterable: Iterable - AsyncIterator: Iterator - StopAsyncIteration: StopIteration - Awaitable: None - BigtableDataClientAsync: BigtableDataClient - BigtableAsyncClient: BigtableClient - TableAsync: Table - AsyncMock: mock.Mock - retry_target_async: retry_target - TestBigtableDataClientAsync: TestBigtableDataClient - TestReadRowsAsync: TestReadRows - assert_awaited_once: assert_called_once - assert_awaited: assert_called_once - grpc_helpers_async: grpc_helpers - -classes: - - path: tests.unit.data._async.test__mutate_rows.TestMutateRowsOperation - autogen_sync_name: TestMutateRowsOperation - replace_methods: - _target_class: | - from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation - return _MutateRowsOperation - - path: tests.unit.data._async.test__read_rows.TestReadRowsOperation - autogen_sync_name: TestReadRowsOperation - text_replacements: - test_aclose: test_close - replace_methods: - _get_target_class: | - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - return _ReadRowsOperation - - path: tests.unit.data._async.test_mutations_batcher.Test_FlowControl - autogen_sync_name: Test_FlowControl - replace_methods: - _target_class: | - from google.cloud.bigtable.data._sync.mutations_batcher import _FlowControl - return _FlowControl - - path: tests.unit.data._async.test_mutations_batcher.TestMutationsBatcherAsync - autogen_sync_name: TestMutationsBatcher - replace_methods: - _get_target_class: | - from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher - return MutationsBatcher - is_async: "return False" - - path: tests.unit.data._async.test_client.TestBigtableDataClientAsync - autogen_sync_name: TestBigtableDataClient - added_imports: - - "from google.api_core import grpc_helpers" - - "from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient" - - "from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport" - - "from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledChannel" - text_replacements: - PooledBigtableGrpcAsyncIOTransport: PooledBigtableGrpcTransport - PooledChannelAsync: PooledChannel - TestTableAsync: TestTable - replace_methods: - _get_target_class: | - from google.cloud.bigtable.data._sync.client import BigtableDataClient - return BigtableDataClient - is_async: "return False" - drop_methods: ["test_client_ctor_sync", "test__start_background_channel_refresh_sync", "test__start_background_channel_refresh_tasks_names", "test_close_with_timeout"] - - path: tests.unit.data._async.test_client.TestTableAsync - autogen_sync_name: TestTable - replace_methods: - _get_target_class: | - from google.cloud.bigtable.data._sync.client import Table - return Table - is_async: "return False" - drop_methods: ["test_table_ctor_sync"] - - path: tests.unit.data._async.test_client.TestReadRowsAsync - autogen_sync_name: TestReadRows - replace_methods: - _get_operation_class: | - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - return _ReadRowsOperation - - path: tests.unit.data._async.test_client.TestReadRowsShardedAsync - autogen_sync_name: TestReadRowsSharded - - path: tests.unit.data._async.test_client.TestSampleRowKeysAsync - autogen_sync_name: TestSampleRowKeys - - path: tests.unit.data._async.test_client.TestMutateRowAsync - autogen_sync_name: TestMutateRow - - path: 
tests.unit.data._async.test_client.TestBulkMutateRowsAsync - autogen_sync_name: TestBulkMutateRows - - path: tests.unit.data._async.test_client.TestCheckAndMutateRowAsync - autogen_sync_name: TestCheckAndMutateRow - - path: tests.unit.data._async.test_client.TestReadModifyWriteRowAsync - autogen_sync_name: TestReadModifyWriteRow - - path: tests.unit.data._async.test_read_rows_acceptance.TestReadRowsAcceptanceAsync - autogen_sync_name: TestReadRowsAcceptance - replace_methods: - _get_operation_class: | - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - return _ReadRowsOperation - _get_client_class: | - from google.cloud.bigtable.data._sync.client import BigtableDataClient - return BigtableDataClient - -save_path: "tests/unit/data/_sync/test_autogen.py" diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py deleted file mode 100644 index e7330bf57..000000000 --- a/tests/system/data/test_system.py +++ /dev/null @@ -1,812 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is automatically generated by CrossSync. Do not edit manually. -import pytest -import uuid -import os -from google.api_core import retry -from google.api_core.exceptions import ClientError -from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE -from google.cloud.environment_vars import BIGTABLE_EMULATOR -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -if CrossSync._Sync_Impl.is_async: - pass -else: - from google.cloud.bigtable.data._sync.client import BigtableDataClient - from .test_system_async import TEST_FAMILY, TEST_FAMILY_2 - - -class TempRowBuilder: - """ - Used to add rows to a table for testing purposes. 
- """ - - def __init__(self, table): - self.rows = [] - self.table = table - - def add_row( - self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" - ): - if isinstance(value, str): - value = value.encode("utf-8") - elif isinstance(value, int): - value = value.to_bytes(8, byteorder="big", signed=True) - request = { - "table_name": self.table.table_name, - "row_key": row_key, - "mutations": [ - { - "set_cell": { - "family_name": family, - "column_qualifier": qualifier, - "value": value, - } - } - ], - } - self.table.client._gapic_client.mutate_row(request) - self.rows.append(row_key) - - def delete_rows(self): - if self.rows: - request = { - "table_name": self.table.table_name, - "entries": [ - {"row_key": row, "mutations": [{"delete_from_row": {}}]} - for row in self.rows - ], - } - self.table.client._gapic_client.mutate_rows(request) - - -class TestSystem: - @pytest.fixture(scope="session") - def client(self): - project = os.getenv("GOOGLE_CLOUD_PROJECT") or None - with BigtableDataClient(project=project, pool_size=4) as client: - yield client - - @pytest.fixture(scope="session") - def table(self, client, table_id, instance_id): - with client.get_table(instance_id, table_id) as table: - yield table - - @pytest.fixture(scope="session") - def column_family_config(self): - """specify column families to create when creating a new test table""" - from google.cloud.bigtable_admin_v2 import types - - return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} - - @pytest.fixture(scope="session") - def init_table_id(self): - """The table_id to use when creating a new test table""" - return f"test-table-{uuid.uuid4().hex}" - - @pytest.fixture(scope="session") - def cluster_config(self, project_id): - """Configuration for the clusters to use when creating a new instance""" - from google.cloud.bigtable_admin_v2 import types - - cluster = { - "test-cluster": types.Cluster( - location=f"projects/{project_id}/locations/us-central1-b", serve_nodes=1 - ) - } - return cluster - - @pytest.mark.usefixtures("table") - def _retrieve_cell_value(self, table, row_key): - """Helper to read an individual row""" - from google.cloud.bigtable.data import ReadRowsQuery - - row_list = table.read_rows(ReadRowsQuery(row_keys=row_key)) - assert len(row_list) == 1 - row = row_list[0] - cell = row.cells[0] - return cell.value - - def _create_row_and_mutation( - self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" - ): - """Helper to create a new row, and a sample set_cell mutation to change its value""" - from google.cloud.bigtable.data.mutations import SetCell - - row_key = uuid.uuid4().hex.encode() - family = TEST_FAMILY - qualifier = b"test-qualifier" - temp_rows.add_row( - row_key, family=family, qualifier=qualifier, value=start_value - ) - assert self._retrieve_cell_value(table, row_key) == start_value - mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) - return (row_key, mutation) - - @pytest.fixture(scope="function") - def temp_rows(self, table): - builder = TempRowBuilder(table) - yield builder - builder.delete_rows() - - @pytest.mark.usefixtures("table") - @pytest.mark.usefixtures("client") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=10 - ) - def test_ping_and_warm_gapic(self, client, table): - """Simple ping rpc test - This test ensures channels are able to authenticate with backend""" - request = {"name": table.instance_name} - 
client._gapic_client.ping_and_warm(request) - - @pytest.mark.usefixtures("table") - @pytest.mark.usefixtures("client") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_ping_and_warm(self, client, table): - """Test ping and warm from handwritten client""" - try: - channel = client.transport._grpc_channel.pool[0] - except Exception: - channel = client.transport._grpc_channel - results = client._ping_and_warm_instances(channel) - assert len(results) == 1 - assert results[0] is None - - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_mutation_set_cell(self, table, temp_rows): - """Ensure cells can be set properly""" - row_key = b"bulk_mutate" - new_value = uuid.uuid4().hex.encode() - row_key, mutation = self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - table.mutate_row(row_key, mutation) - assert self._retrieve_cell_value(table, row_key) == new_value - - @pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" - ) - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_sample_row_keys(self, client, table, temp_rows, column_split_config): - """Sample keys should return a single sample in small test tables""" - temp_rows.add_row(b"row_key_1") - temp_rows.add_row(b"row_key_2") - results = table.sample_row_keys() - assert len(results) == len(column_split_config) + 1 - for idx in range(len(column_split_config)): - assert results[idx][0] == column_split_config[idx] - assert isinstance(results[idx][1], int) - assert results[-1][0] == b"" - assert isinstance(results[-1][1], int) - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - def test_bulk_mutations_set_cell(self, client, table, temp_rows): - """Ensure cells can be set properly""" - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - row_key, mutation = self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - table.bulk_mutate_rows([bulk_mutation]) - assert self._retrieve_cell_value(table, row_key) == new_value - - def test_bulk_mutations_raise_exception(self, client, table): - """If an invalid mutation is passed, an exception should be raised""" - from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell - from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup - from google.cloud.bigtable.data.exceptions import FailedMutationEntryError - - row_key = uuid.uuid4().hex.encode() - mutation = SetCell( - family="nonexistent", qualifier=b"test-qualifier", new_value=b"" - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - with pytest.raises(MutationsExceptionGroup) as exc: - table.bulk_mutate_rows([bulk_mutation]) - assert len(exc.value.exceptions) == 1 - entry_error = exc.value.exceptions[0] - assert isinstance(entry_error, FailedMutationEntryError) - assert entry_error.index == 0 - assert entry_error.entry == bulk_mutation - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_mutations_batcher_context_manager(self, client, table, temp_rows): - 
"""test batcher with context manager. Should flush on exit""" - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - row_key2, mutation2 = self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - with table.mutations_batcher() as batcher: - batcher.append(bulk_mutation) - batcher.append(bulk_mutation2) - assert self._retrieve_cell_value(table, row_key) == new_value - assert len(batcher._staged_entries) == 0 - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_mutations_batcher_timer_flush(self, client, table, temp_rows): - """batch should occur after flush_interval seconds""" - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - row_key, mutation = self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - flush_interval = 0.1 - with table.mutations_batcher(flush_interval=flush_interval) as batcher: - batcher.append(bulk_mutation) - CrossSync._Sync_Impl.yield_to_event_loop() - assert len(batcher._staged_entries) == 1 - CrossSync._Sync_Impl.sleep(flush_interval + 0.1) - assert len(batcher._staged_entries) == 0 - assert self._retrieve_cell_value(table, row_key) == new_value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_mutations_batcher_count_flush(self, client, table, temp_rows): - """batch should flush after flush_limit_mutation_count mutations""" - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: - batcher.append(bulk_mutation) - assert len(batcher._flush_jobs) == 0 - assert len(batcher._staged_entries) == 1 - batcher.append(bulk_mutation2) - assert len(batcher._flush_jobs) == 1 - for future in list(batcher._flush_jobs): - future - future.result() - assert len(batcher._staged_entries) == 0 - assert len(batcher._flush_jobs) == 0 - assert self._retrieve_cell_value(table, row_key) == new_value - assert self._retrieve_cell_value(table, row_key2) == new_value2 - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): - """batch should flush after flush_limit_bytes bytes""" - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - 
bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 - with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: - batcher.append(bulk_mutation) - assert len(batcher._flush_jobs) == 0 - assert len(batcher._staged_entries) == 1 - batcher.append(bulk_mutation2) - assert len(batcher._flush_jobs) == 1 - assert len(batcher._staged_entries) == 0 - for future in list(batcher._flush_jobs): - future - future.result() - assert self._retrieve_cell_value(table, row_key) == new_value - assert self._retrieve_cell_value(table, row_key2) == new_value2 - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - def test_mutations_batcher_no_flush(self, client, table, temp_rows): - """test with no flush requirements met""" - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - start_value = b"unchanged" - row_key, mutation = self._create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = self._create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 - with table.mutations_batcher( - flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 - ) as batcher: - batcher.append(bulk_mutation) - assert len(batcher._staged_entries) == 1 - batcher.append(bulk_mutation2) - assert len(batcher._flush_jobs) == 0 - CrossSync._Sync_Impl.yield_to_event_loop() - assert len(batcher._staged_entries) == 2 - assert len(batcher._flush_jobs) == 0 - assert self._retrieve_cell_value(table, row_key) == start_value - assert self._retrieve_cell_value(table, row_key2) == start_value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @pytest.mark.parametrize( - "start,increment,expected", - [ - (0, 0, 0), - (0, 1, 1), - (0, -1, -1), - (1, 0, 1), - (0, -100, -100), - (0, 3000, 3000), - (10, 4, 14), - (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), - (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), - (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), - ], - ) - def test_read_modify_write_row_increment( - self, client, table, temp_rows, start, increment, expected - ): - """test read_modify_write_row""" - from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) - rule = IncrementRule(family, qualifier, increment) - result = table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert len(result) == 1 - assert result[0].family == family - assert result[0].qualifier == qualifier - assert int(result[0]) == expected - assert self._retrieve_cell_value(table, row_key) == result[0].value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @pytest.mark.parametrize( - "start,append,expected", - [ - (b"", b"", b""), - ("", "", b""), - (b"abc", b"123", b"abc123"), - (b"abc", "123", b"abc123"), - ("", b"1", b"1"), - (b"abc", "", b"abc"), - (b"hello", b"world", b"helloworld"), - ], - ) - def 
test_read_modify_write_row_append( - self, client, table, temp_rows, start, append, expected - ): - """test read_modify_write_row""" - from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) - rule = AppendValueRule(family, qualifier, append) - result = table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert len(result) == 1 - assert result[0].family == family - assert result[0].qualifier == qualifier - assert result[0].value == expected - assert self._retrieve_cell_value(table, row_key) == result[0].value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - def test_read_modify_write_row_chained(self, client, table, temp_rows): - """test read_modify_write_row with multiple rules""" - from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule - from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - start_amount = 1 - increment_amount = 10 - temp_rows.add_row( - row_key, value=start_amount, family=family, qualifier=qualifier - ) - rule = [ - IncrementRule(family, qualifier, increment_amount), - AppendValueRule(family, qualifier, "hello"), - AppendValueRule(family, qualifier, "world"), - AppendValueRule(family, qualifier, "!"), - ] - result = table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert result[0].family == family - assert result[0].qualifier == qualifier - assert ( - result[0].value - == (start_amount + increment_amount).to_bytes(8, "big", signed=True) - + b"helloworld!" 
- ) - assert self._retrieve_cell_value(table, row_key) == result[0].value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @pytest.mark.parametrize( - "start_val,predicate_range,expected_result", - [(1, (0, 2), True), (-1, (0, 2), False)], - ) - def test_check_and_mutate( - self, client, table, temp_rows, start_val, predicate_range, expected_result - ): - """test that check_and_mutate_row works applies the right mutations, and returns the right result""" - from google.cloud.bigtable.data.mutations import SetCell - from google.cloud.bigtable.data.row_filters import ValueRangeFilter - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - temp_rows.add_row(row_key, value=start_val, family=family, qualifier=qualifier) - false_mutation_value = b"false-mutation-value" - false_mutation = SetCell( - family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value - ) - true_mutation_value = b"true-mutation-value" - true_mutation = SetCell( - family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value - ) - predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) - result = table.check_and_mutate_row( - row_key, - predicate, - true_case_mutations=true_mutation, - false_case_mutations=false_mutation, - ) - assert result == expected_result - expected_value = ( - true_mutation_value if expected_result else false_mutation_value - ) - assert self._retrieve_cell_value(table, row_key) == expected_value - - @pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", - ) - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - def test_check_and_mutate_empty_request(self, client, table): - """check_and_mutate with no true or fale mutations should raise an error""" - from google.api_core import exceptions - - with pytest.raises(exceptions.InvalidArgument) as e: - table.check_and_mutate_row( - b"row_key", None, true_case_mutations=None, false_case_mutations=None - ) - assert "No mutations provided" in str(e.value) - - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_read_rows_stream(self, table, temp_rows): - """Ensure that the read_rows_stream method works""" - temp_rows.add_row(b"row_key_1") - temp_rows.add_row(b"row_key_2") - generator = table.read_rows_stream({}) - first_row = generator.__next__() - second_row = generator.__next__() - assert first_row.row_key == b"row_key_1" - assert second_row.row_key == b"row_key_2" - with pytest.raises(CrossSync._Sync_Impl.StopIteration): - generator.__next__() - - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_read_rows(self, table, temp_rows): - """Ensure that the read_rows method works""" - temp_rows.add_row(b"row_key_1") - temp_rows.add_row(b"row_key_2") - row_list = table.read_rows({}) - assert len(row_list) == 2 - assert row_list[0].row_key == b"row_key_1" - assert row_list[1].row_key == b"row_key_2" - - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_read_rows_sharded_simple(self, table, temp_rows): - """Test read rows sharded with two queries""" - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - - temp_rows.add_row(b"a") - temp_rows.add_row(b"b") - temp_rows.add_row(b"c") - 
temp_rows.add_row(b"d") - query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) - query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) - row_list = table.read_rows_sharded([query1, query2]) - assert len(row_list) == 4 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"c" - assert row_list[2].row_key == b"b" - assert row_list[3].row_key == b"d" - - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_read_rows_sharded_from_sample(self, table, temp_rows): - """Test end-to-end sharding""" - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - from google.cloud.bigtable.data.read_rows_query import RowRange - - temp_rows.add_row(b"a") - temp_rows.add_row(b"b") - temp_rows.add_row(b"c") - temp_rows.add_row(b"d") - table_shard_keys = table.sample_row_keys() - query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) - shard_queries = query.shard(table_shard_keys) - row_list = table.read_rows_sharded(shard_queries) - assert len(row_list) == 3 - assert row_list[0].row_key == b"b" - assert row_list[1].row_key == b"c" - assert row_list[2].row_key == b"d" - - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_read_rows_sharded_filters_limits(self, table, temp_rows): - """Test read rows sharded with filters and limits""" - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - temp_rows.add_row(b"a") - temp_rows.add_row(b"b") - temp_rows.add_row(b"c") - temp_rows.add_row(b"d") - label_filter1 = ApplyLabelFilter("first") - label_filter2 = ApplyLabelFilter("second") - query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) - query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) - row_list = table.read_rows_sharded([query1, query2]) - assert len(row_list) == 3 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"b" - assert row_list[2].row_key == b"d" - assert row_list[0][0].labels == ["first"] - assert row_list[1][0].labels == ["second"] - assert row_list[2][0].labels == ["second"] - - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_read_rows_range_query(self, table, temp_rows): - """Ensure that the read_rows method works""" - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable.data import RowRange - - temp_rows.add_row(b"a") - temp_rows.add_row(b"b") - temp_rows.add_row(b"c") - temp_rows.add_row(b"d") - query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) - row_list = table.read_rows(query) - assert len(row_list) == 2 - assert row_list[0].row_key == b"b" - assert row_list[1].row_key == b"c" - - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_read_rows_single_key_query(self, table, temp_rows): - """Ensure that the read_rows method works with specified query""" - from google.cloud.bigtable.data import ReadRowsQuery - - temp_rows.add_row(b"a") - temp_rows.add_row(b"b") - temp_rows.add_row(b"c") - temp_rows.add_row(b"d") - query = ReadRowsQuery(row_keys=[b"a", b"c"]) - row_list = table.read_rows(query) - assert len(row_list) == 2 - assert row_list[0].row_key == b"a" - assert 
row_list[1].row_key == b"c" - - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - def test_read_rows_with_filter(self, table, temp_rows): - """ensure filters are applied""" - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - temp_rows.add_row(b"a") - temp_rows.add_row(b"b") - temp_rows.add_row(b"c") - temp_rows.add_row(b"d") - expected_label = "test-label" - row_filter = ApplyLabelFilter(expected_label) - query = ReadRowsQuery(row_filter=row_filter) - row_list = table.read_rows(query) - assert len(row_list) == 4 - for row in row_list: - assert row[0].labels == [expected_label] - - @pytest.mark.usefixtures("table") - def test_read_rows_stream_close(self, table, temp_rows): - """Ensure that the read_rows_stream can be closed""" - from google.cloud.bigtable.data import ReadRowsQuery - - temp_rows.add_row(b"row_key_1") - temp_rows.add_row(b"row_key_2") - query = ReadRowsQuery() - generator = table.read_rows_stream(query) - first_row = generator.__next__() - assert first_row.row_key == b"row_key_1" - generator.close() - with pytest.raises(CrossSync._Sync_Impl.StopIteration): - generator.__next__() - - @pytest.mark.usefixtures("table") - def test_read_row(self, table, temp_rows): - """Test read_row (single row helper)""" - from google.cloud.bigtable.data import Row - - temp_rows.add_row(b"row_key_1", value=b"value") - row = table.read_row(b"row_key_1") - assert isinstance(row, Row) - assert row.row_key == b"row_key_1" - assert row.cells[0].value == b"value" - - @pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", - ) - @pytest.mark.usefixtures("table") - def test_read_row_missing(self, table): - """Test read_row when row does not exist""" - from google.api_core import exceptions - - row_key = "row_key_not_exist" - result = table.read_row(row_key) - assert result is None - with pytest.raises(exceptions.InvalidArgument) as e: - table.read_row("") - assert "Row keys must be non-empty" in str(e) - - @pytest.mark.usefixtures("table") - def test_read_row_w_filter(self, table, temp_rows): - """Test read_row (single row helper)""" - from google.cloud.bigtable.data import Row - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - temp_rows.add_row(b"row_key_1", value=b"value") - expected_label = "test-label" - label_filter = ApplyLabelFilter(expected_label) - row = table.read_row(b"row_key_1", row_filter=label_filter) - assert isinstance(row, Row) - assert row.row_key == b"row_key_1" - assert row.cells[0].value == b"value" - assert row.cells[0].labels == [expected_label] - - @pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", - ) - @pytest.mark.usefixtures("table") - def test_row_exists(self, table, temp_rows): - from google.api_core import exceptions - - "Test row_exists with rows that exist and don't exist" - assert table.row_exists(b"row_key_1") is False - temp_rows.add_row(b"row_key_1") - assert table.row_exists(b"row_key_1") is True - assert table.row_exists("row_key_1") is True - assert table.row_exists(b"row_key_2") is False - assert table.row_exists("row_key_2") is False - assert table.row_exists("3") is False - temp_rows.add_row(b"3") - assert table.row_exists(b"3") is True - with pytest.raises(exceptions.InvalidArgument) as e: - table.row_exists("") - assert "Row keys must be non-empty" 
in str(e) - - @pytest.mark.usefixtures("table") - @CrossSync._Sync_Impl.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @pytest.mark.parametrize( - "cell_value,filter_input,expect_match", - [ - (b"abc", b"abc", True), - (b"abc", "abc", True), - (b".", ".", True), - (".*", ".*", True), - (".*", b".*", True), - ("a", ".*", False), - (b".*", b".*", True), - ("\\a", "\\a", True), - (b"\xe2\x98\x83", "☃", True), - ("☃", "☃", True), - ("\\C☃", "\\C☃", True), - (1, 1, True), - (2, 1, False), - (68, 68, True), - ("D", 68, False), - (68, "D", False), - (-1, -1, True), - (2852126720, 2852126720, True), - (-1431655766, -1431655766, True), - (-1431655766, -1, False), - ], - ) - def test_literal_value_filter( - self, table, temp_rows, cell_value, filter_input, expect_match - ): - """Literal value filter does complex escaping on re2 strings. - Make sure inputs are properly interpreted by the server""" - from google.cloud.bigtable.data.row_filters import LiteralValueFilter - from google.cloud.bigtable.data import ReadRowsQuery - - f = LiteralValueFilter(filter_input) - temp_rows.add_row(b"row_key_1", value=cell_value) - query = ReadRowsQuery(row_filter=f) - row_list = table.read_rows(query) - assert len(row_list) == bool( - expect_match - ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index db3da531d..85adae2d2 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -282,7 +282,9 @@ async def test_add_to_flow_max_mutation_limits( "google.cloud.bigtable.data._sync.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", max_limit, ) - with async_patch, sync_patch: + subpath = "_async" if CrossSync.is_async else "_sync" + path = f"google.cloud.bigtable.data.{subpath}.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT" + with mock.patch(path, max_limit): mutation_objs = [ self._make_mutation(count=m[0], size=m[1]) for m in mutations ] diff --git a/tests/unit/data/_sync/__init__.py b/tests/unit/data/_sync/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/data/_sync/test__mutate_rows.py b/tests/unit/data/_sync/test__mutate_rows.py deleted file mode 100644 index 63d4009c6..000000000 --- a/tests/unit/data/_sync/test__mutate_rows.py +++ /dev/null @@ -1,321 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is automatically generated by CrossSync. Do not edit manually. 
-import pytest -from google.cloud.bigtable_v2.types import MutateRowsResponse -from google.rpc import status_pb2 -from google.api_core.exceptions import DeadlineExceeded -from google.api_core.exceptions import Forbidden -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -try: - from unittest import mock -except ImportError: - import mock - - -class TestMutateRowsOperation: - def _target_class(self): - if CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable.data._async._mutate_rows import ( - _MutateRowsOperationAsync, - ) - - return _MutateRowsOperationAsync - else: - from google.cloud.bigtable.data._sync._mutate_rows import ( - _MutateRowsOperation, - ) - - return _MutateRowsOperation - - def _make_one(self, *args, **kwargs): - if not args: - kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) - kwargs["table"] = kwargs.pop("table", CrossSync._Sync_Impl.Mock()) - kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) - kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) - kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) - kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) - return self._target_class()(*args, **kwargs) - - def _make_mutation(self, count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation - - def _mock_stream(self, mutation_list, error_dict): - for idx, entry in enumerate(mutation_list): - code = error_dict.get(idx, 0) - yield MutateRowsResponse( - entries=[ - MutateRowsResponse.Entry( - index=idx, status=status_pb2.Status(code=code) - ) - ] - ) - - def _make_mock_gapic(self, mutation_list, error_dict=None): - mock_fn = CrossSync._Sync_Impl.Mock() - if error_dict is None: - error_dict = {} - mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( - mutation_list, error_dict - ) - return mock_fn - - def test_ctor(self): - """test that constructor sets all the attributes correctly""" - from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto - from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - from google.api_core.exceptions import DeadlineExceeded - from google.api_core.exceptions import Aborted - - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] - operation_timeout = 0.05 - attempt_timeout = 0.01 - retryable_exceptions = () - instance = self._make_one( - client, - table, - entries, - operation_timeout, - attempt_timeout, - retryable_exceptions, - ) - assert client.mutate_rows.call_count == 0 - instance._gapic_fn() - assert client.mutate_rows.call_count == 1 - inner_kwargs = client.mutate_rows.call_args[1] - assert len(inner_kwargs) == 4 - assert inner_kwargs["table_name"] == table.table_name - assert inner_kwargs["app_profile_id"] == table.app_profile_id - assert inner_kwargs["retry"] is None - metadata = inner_kwargs["metadata"] - assert len(metadata) == 1 - assert metadata[0][0] == "x-goog-request-params" - assert str(table.table_name) in metadata[0][1] - assert str(table.app_profile_id) in metadata[0][1] - entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries] - assert instance.mutations == entries_w_pb - assert next(instance.timeout_generator) == attempt_timeout - assert instance.is_retryable is not None - assert instance.is_retryable(DeadlineExceeded("")) is False - assert instance.is_retryable(Aborted("")) is False - assert instance.is_retryable(_MutateRowsIncomplete("")) is True 
- assert instance.is_retryable(RuntimeError("")) is False - assert instance.remaining_indices == list(range(len(entries))) - assert instance.errors == {} - - def test_ctor_too_many_entries(self): - """should raise an error if an operation is created with more than 100,000 entries""" - from google.cloud.bigtable.data._async._mutate_rows import ( - _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, - ) - - assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100000 - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation()] * (_MUTATE_ROWS_REQUEST_MUTATION_LIMIT + 1) - operation_timeout = 0.05 - attempt_timeout = 0.01 - with pytest.raises(ValueError) as e: - self._make_one(client, table, entries, operation_timeout, attempt_timeout) - assert "mutate_rows requests can contain at most 100000 mutations" in str( - e.value - ) - assert "Found 100001" in str(e.value) - - def test_mutate_rows_operation(self): - """Test successful case of mutate_rows_operation""" - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] - operation_timeout = 0.05 - cls = self._target_class() - with mock.patch( - f"{cls.__module__}.{cls.__name__}._run_attempt", CrossSync._Sync_Impl.Mock() - ) as attempt_mock: - instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout - ) - instance.start() - assert attempt_mock.call_count == 1 - - @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) - def test_mutate_rows_attempt_exception(self, exc_type): - """exceptions raised from attempt should be raised in MutationsExceptionGroup""" - client = CrossSync._Sync_Impl.Mock() - table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] - operation_timeout = 0.05 - expected_exception = exc_type("test") - client.mutate_rows.side_effect = expected_exception - found_exc = None - try: - instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout - ) - instance._run_attempt() - except Exception as e: - found_exc = e - assert client.mutate_rows.call_count == 1 - assert type(found_exc) is exc_type - assert found_exc == expected_exception - assert len(instance.errors) == 2 - assert len(instance.remaining_indices) == 0 - - @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) - def test_mutate_rows_exception(self, exc_type): - """exceptions raised from retryable should be raised in MutationsExceptionGroup""" - from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup - from google.cloud.bigtable.data.exceptions import FailedMutationEntryError - - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] - operation_timeout = 0.05 - expected_cause = exc_type("abort") - with mock.patch.object( - self._target_class(), "_run_attempt", CrossSync._Sync_Impl.Mock() - ) as attempt_mock: - attempt_mock.side_effect = expected_cause - found_exc = None - try: - instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout - ) - instance.start() - except MutationsExceptionGroup as e: - found_exc = e - assert attempt_mock.call_count == 1 - assert len(found_exc.exceptions) == 2 - assert isinstance(found_exc.exceptions[0], FailedMutationEntryError) - assert isinstance(found_exc.exceptions[1], FailedMutationEntryError) - assert found_exc.exceptions[0].__cause__ == expected_cause - assert found_exc.exceptions[1].__cause__ == expected_cause - - @pytest.mark.parametrize("exc_type", 
[DeadlineExceeded, RuntimeError]) - def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): - """If an exception fails but eventually passes, it should not raise an exception""" - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation()] - operation_timeout = 1 - expected_cause = exc_type("retry") - num_retries = 2 - with mock.patch.object( - self._target_class(), "_run_attempt", CrossSync._Sync_Impl.Mock() - ) as attempt_mock: - attempt_mock.side_effect = [expected_cause] * num_retries + [None] - instance = self._make_one( - client, - table, - entries, - operation_timeout, - operation_timeout, - retryable_exceptions=(exc_type,), - ) - instance.start() - assert attempt_mock.call_count == num_retries + 1 - - def test_mutate_rows_incomplete_ignored(self): - """MutateRowsIncomplete exceptions should not be added to error list""" - from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup - from google.api_core.exceptions import DeadlineExceeded - - client = mock.Mock() - table = mock.Mock() - entries = [self._make_mutation()] - operation_timeout = 0.05 - with mock.patch.object( - self._target_class(), "_run_attempt", CrossSync._Sync_Impl.Mock() - ) as attempt_mock: - attempt_mock.side_effect = _MutateRowsIncomplete("ignored") - found_exc = None - try: - instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout - ) - instance.start() - except MutationsExceptionGroup as e: - found_exc = e - assert attempt_mock.call_count > 0 - assert len(found_exc.exceptions) == 1 - assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) - - def test_run_attempt_single_entry_success(self): - """Test mutating a single entry""" - mutation = self._make_mutation() - expected_timeout = 1.3 - mock_gapic_fn = self._make_mock_gapic({0: mutation}) - instance = self._make_one( - mutation_entries=[mutation], attempt_timeout=expected_timeout - ) - with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): - instance._run_attempt() - assert len(instance.remaining_indices) == 0 - assert mock_gapic_fn.call_count == 1 - _, kwargs = mock_gapic_fn.call_args - assert kwargs["timeout"] == expected_timeout - assert kwargs["entries"] == [mutation._to_pb()] - - def test_run_attempt_empty_request(self): - """Calling with no mutations should result in no API calls""" - mock_gapic_fn = self._make_mock_gapic([]) - instance = self._make_one(mutation_entries=[]) - instance._run_attempt() - assert mock_gapic_fn.call_count == 0 - - def test_run_attempt_partial_success_retryable(self): - """Some entries succeed, but one fails. 
Should report the proper index, and raise incomplete exception""" - from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - - success_mutation = self._make_mutation() - success_mutation_2 = self._make_mutation() - failure_mutation = self._make_mutation() - mutations = [success_mutation, failure_mutation, success_mutation_2] - mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) - instance = self._make_one(mutation_entries=mutations) - instance.is_retryable = lambda x: True - with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): - with pytest.raises(_MutateRowsIncomplete): - instance._run_attempt() - assert instance.remaining_indices == [1] - assert 0 not in instance.errors - assert len(instance.errors[1]) == 1 - assert instance.errors[1][0].grpc_status_code == 300 - assert 2 not in instance.errors - - def test_run_attempt_partial_success_non_retryable(self): - """Some entries succeed, but one fails. Exception marked as non-retryable. Do not raise incomplete error""" - success_mutation = self._make_mutation() - success_mutation_2 = self._make_mutation() - failure_mutation = self._make_mutation() - mutations = [success_mutation, failure_mutation, success_mutation_2] - mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) - instance = self._make_one(mutation_entries=mutations) - instance.is_retryable = lambda x: False - with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): - instance._run_attempt() - assert instance.remaining_indices == [] - assert 0 not in instance.errors - assert len(instance.errors[1]) == 1 - assert instance.errors[1][0].grpc_status_code == 300 - assert 2 not in instance.errors diff --git a/tests/unit/data/_sync/test__read_rows.py b/tests/unit/data/_sync/test__read_rows.py deleted file mode 100644 index 296e8e7f9..000000000 --- a/tests/unit/data/_sync/test__read_rows.py +++ /dev/null @@ -1,360 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is automatically generated by CrossSync. Do not edit manually. 
-import pytest -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -if CrossSync._Sync_Impl.is_async: - pass -else: - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation -try: - from unittest import mock -except ImportError: - import mock - - -class TestReadRowsOperation: - """ - Tests helper functions in the ReadRowsOperation class - in-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt - is tested in test_read_rows_acceptance test_client_read_rows, and conformance tests - """ - - @staticmethod - def _get_target_class(): - return _ReadRowsOperation - - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) - - def test_ctor(self): - from google.cloud.bigtable.data import ReadRowsQuery - - row_limit = 91 - query = ReadRowsQuery(limit=row_limit) - client = mock.Mock() - client.read_rows = mock.Mock() - client.read_rows.return_value = None - table = mock.Mock() - table._client = client - table.table_name = "test_table" - table.app_profile_id = "test_profile" - expected_operation_timeout = 42 - expected_request_timeout = 44 - time_gen_mock = mock.Mock() - with mock.patch( - "google.cloud.bigtable.data._helpers._attempt_timeout_generator", - time_gen_mock, - ): - instance = self._make_one( - query, - table, - operation_timeout=expected_operation_timeout, - attempt_timeout=expected_request_timeout, - ) - assert time_gen_mock.call_count == 1 - time_gen_mock.assert_called_once_with( - expected_request_timeout, expected_operation_timeout - ) - assert instance._last_yielded_row_key is None - assert instance._remaining_count == row_limit - assert instance.operation_timeout == expected_operation_timeout - assert client.read_rows.call_count == 0 - assert instance._metadata == [ - ( - "x-goog-request-params", - "table_name=test_table&app_profile_id=test_profile", - ) - ] - assert instance.request.table_name == table.table_name - assert instance.request.app_profile_id == table.app_profile_id - assert instance.request.rows_limit == row_limit - - @pytest.mark.parametrize( - "in_keys,last_key,expected", - [ - (["b", "c", "d"], "a", ["b", "c", "d"]), - (["a", "b", "c"], "b", ["c"]), - (["a", "b", "c"], "c", []), - (["a", "b", "c"], "d", []), - (["d", "c", "b", "a"], "b", ["d", "c"]), - ], - ) - @pytest.mark.parametrize("with_range", [True, False]) - def test_revise_request_rowset_keys_with_range( - self, in_keys, last_key, expected, with_range - ): - from google.cloud.bigtable_v2.types import RowSet as RowSetPB - from google.cloud.bigtable_v2.types import RowRange as RowRangePB - from google.cloud.bigtable.data.exceptions import _RowSetComplete - - in_keys = [key.encode("utf-8") for key in in_keys] - expected = [key.encode("utf-8") for key in expected] - last_key = last_key.encode("utf-8") - if with_range: - sample_range = [RowRangePB(start_key_open=last_key)] - else: - sample_range = [] - row_set = RowSetPB(row_keys=in_keys, row_ranges=sample_range) - if not with_range and expected == []: - with pytest.raises(_RowSetComplete): - self._get_target_class()._revise_request_rowset(row_set, last_key) - else: - revised = self._get_target_class()._revise_request_rowset(row_set, last_key) - assert revised.row_keys == expected - assert revised.row_ranges == sample_range - - @pytest.mark.parametrize( - "in_ranges,last_key,expected", - [ - ( - [{"start_key_open": "b", "end_key_closed": "d"}], - "a", - [{"start_key_open": "b", "end_key_closed": "d"}], - ), - ( - [{"start_key_closed": "b", "end_key_closed": "d"}], - 
"a", - [{"start_key_closed": "b", "end_key_closed": "d"}], - ), - ( - [{"start_key_open": "a", "end_key_closed": "d"}], - "b", - [{"start_key_open": "b", "end_key_closed": "d"}], - ), - ( - [{"start_key_closed": "a", "end_key_open": "d"}], - "b", - [{"start_key_open": "b", "end_key_open": "d"}], - ), - ( - [{"start_key_closed": "b", "end_key_closed": "d"}], - "b", - [{"start_key_open": "b", "end_key_closed": "d"}], - ), - ([{"start_key_closed": "b", "end_key_closed": "d"}], "d", []), - ([{"start_key_closed": "b", "end_key_open": "d"}], "d", []), - ([{"start_key_closed": "b", "end_key_closed": "d"}], "e", []), - ([{"start_key_closed": "b"}], "z", [{"start_key_open": "z"}]), - ([{"start_key_closed": "b"}], "a", [{"start_key_closed": "b"}]), - ( - [{"end_key_closed": "z"}], - "a", - [{"start_key_open": "a", "end_key_closed": "z"}], - ), - ( - [{"end_key_open": "z"}], - "a", - [{"start_key_open": "a", "end_key_open": "z"}], - ), - ], - ) - @pytest.mark.parametrize("with_key", [True, False]) - def test_revise_request_rowset_ranges( - self, in_ranges, last_key, expected, with_key - ): - from google.cloud.bigtable_v2.types import RowSet as RowSetPB - from google.cloud.bigtable_v2.types import RowRange as RowRangePB - from google.cloud.bigtable.data.exceptions import _RowSetComplete - - next_key = (last_key + "a").encode("utf-8") - last_key = last_key.encode("utf-8") - in_ranges = [ - RowRangePB(**{k: v.encode("utf-8") for k, v in r.items()}) - for r in in_ranges - ] - expected = [ - RowRangePB(**{k: v.encode("utf-8") for k, v in r.items()}) for r in expected - ] - if with_key: - row_keys = [next_key] - else: - row_keys = [] - row_set = RowSetPB(row_ranges=in_ranges, row_keys=row_keys) - if not with_key and expected == []: - with pytest.raises(_RowSetComplete): - self._get_target_class()._revise_request_rowset(row_set, last_key) - else: - revised = self._get_target_class()._revise_request_rowset(row_set, last_key) - assert revised.row_keys == row_keys - assert revised.row_ranges == expected - - @pytest.mark.parametrize("last_key", ["a", "b", "c"]) - def test_revise_request_full_table(self, last_key): - from google.cloud.bigtable_v2.types import RowSet as RowSetPB - from google.cloud.bigtable_v2.types import RowRange as RowRangePB - - last_key = last_key.encode("utf-8") - row_set = RowSetPB() - for selected_set in [row_set, None]: - revised = self._get_target_class()._revise_request_rowset( - selected_set, last_key - ) - assert revised.row_keys == [] - assert len(revised.row_ranges) == 1 - assert revised.row_ranges[0] == RowRangePB(start_key_open=last_key) - - def test_revise_to_empty_rowset(self): - """revising to an empty rowset should raise error""" - from google.cloud.bigtable.data.exceptions import _RowSetComplete - from google.cloud.bigtable_v2.types import RowSet as RowSetPB - from google.cloud.bigtable_v2.types import RowRange as RowRangePB - - row_keys = [b"a", b"b", b"c"] - row_range = RowRangePB(end_key_open=b"c") - row_set = RowSetPB(row_keys=row_keys, row_ranges=[row_range]) - with pytest.raises(_RowSetComplete): - self._get_target_class()._revise_request_rowset(row_set, b"d") - - @pytest.mark.parametrize( - "start_limit,emit_num,expected_limit", - [ - (10, 0, 10), - (10, 1, 9), - (10, 10, 0), - (None, 10, None), - (None, 0, None), - (4, 2, 2), - ], - ) - def test_revise_limit(self, start_limit, emit_num, expected_limit): - """revise_limit should revise the request's limit field - - if limit is 0 (unlimited), it should never be revised - - if start_limit-emit_num == 0, the request 
should end early - - if the number emitted exceeds the new limit, an exception should - should be raised (tested in test_revise_limit_over_limit)""" - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable_v2.types import ReadRowsResponse - - def awaitable_stream(): - def mock_stream(): - for i in range(emit_num): - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk( - row_key=str(i).encode(), - family_name="b", - qualifier=b"c", - value=b"d", - commit_row=True, - ) - ] - ) - - return mock_stream() - - query = ReadRowsQuery(limit=start_limit) - table = mock.Mock() - table.table_name = "table_name" - table.app_profile_id = "app_profile_id" - instance = self._make_one(query, table, 10, 10) - assert instance._remaining_count == start_limit - for val in instance.chunk_stream(awaitable_stream()): - pass - assert instance._remaining_count == expected_limit - - @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) - def test_revise_limit_over_limit(self, start_limit, emit_num): - """Should raise runtime error if we get in state where emit_num > start_num - (unless start_num == 0, which represents unlimited)""" - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable_v2.types import ReadRowsResponse - from google.cloud.bigtable.data.exceptions import InvalidChunk - - def awaitable_stream(): - def mock_stream(): - for i in range(emit_num): - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk( - row_key=str(i).encode(), - family_name="b", - qualifier=b"c", - value=b"d", - commit_row=True, - ) - ] - ) - - return mock_stream() - - query = ReadRowsQuery(limit=start_limit) - table = mock.Mock() - table.table_name = "table_name" - table.app_profile_id = "app_profile_id" - instance = self._make_one(query, table, 10, 10) - assert instance._remaining_count == start_limit - with pytest.raises(InvalidChunk) as e: - for val in instance.chunk_stream(awaitable_stream()): - pass - assert "emit count exceeds row limit" in str(e.value) - - def test_close(self): - """should be able to close a stream safely with aclose. 
- Closed generators should raise StopAsyncIteration on next yield""" - - def mock_stream(): - while True: - yield 1 - - with mock.patch.object( - self._get_target_class(), "_read_rows_attempt" - ) as mock_attempt: - instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) - wrapped_gen = mock_stream() - mock_attempt.return_value = wrapped_gen - gen = instance.start_operation() - gen.__next__() - gen.close() - with pytest.raises(CrossSync._Sync_Impl.StopIteration): - gen.__next__() - gen.close() - with pytest.raises(CrossSync._Sync_Impl.StopIteration): - wrapped_gen.__next__() - - def test_retryable_ignore_repeated_rows(self): - """Duplicate rows should cause an invalid chunk error""" - from google.cloud.bigtable.data.exceptions import InvalidChunk - from google.cloud.bigtable_v2.types import ReadRowsResponse - - row_key = b"duplicate" - - def mock_awaitable_stream(): - def mock_stream(): - while True: - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) - ] - ) - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) - ] - ) - - return mock_stream() - - instance = mock.Mock() - instance._last_yielded_row_key = None - instance._remaining_count = None - stream = self._get_target_class().chunk_stream( - instance, mock_awaitable_stream() - ) - stream.__next__() - with pytest.raises(InvalidChunk) as exc: - stream.__next__() - assert "row keys should be strictly increasing" in str(exc.value) diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py deleted file mode 100644 index a6415a0d2..000000000 --- a/tests/unit/data/_sync/test_client.py +++ /dev/null @@ -1,2740 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is automatically generated by CrossSync. Do not edit manually. 
-from __future__ import annotations -import grpc -import asyncio -import re -import pytest -from google.cloud.bigtable.data import mutations -from google.auth.credentials import AnonymousCredentials -from google.cloud.bigtable_v2.types import ReadRowsResponse -from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery -from google.api_core import exceptions as core_exceptions -from google.cloud.bigtable.data.exceptions import InvalidChunk -from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete -from google.cloud.bigtable.data import TABLE_DEFAULT -from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule -from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -try: - from unittest import mock -except ImportError: - import mock -if CrossSync._Sync_Impl.is_async: - pass -else: - from google.api_core import grpc_helpers - from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledBigtableGrpcTransport, - ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledChannel, - ) - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - from google.cloud.bigtable.data._sync.client import Table, BigtableDataClient - - -class TestBigtableDataClient: - @staticmethod - def _get_target_class(): - return BigtableDataClient - - @classmethod - def _make_client(cls, *args, use_emulator=True, **kwargs): - import os - - env_mask = {} - if use_emulator: - env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" - import warnings - - warnings.filterwarnings("ignore", category=RuntimeWarning) - else: - kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) - kwargs["project"] = kwargs.get("project", "project-id") - with mock.patch.dict(os.environ, env_mask): - return cls._get_target_class()(*args, **kwargs) - - def test_ctor(self): - expected_project = "project-id" - expected_pool_size = 11 - expected_credentials = AnonymousCredentials() - client = self._make_client( - project="project-id", - pool_size=expected_pool_size, - credentials=expected_credentials, - use_emulator=False, - ) - CrossSync._Sync_Impl.yield_to_event_loop() - assert client.project == expected_project - assert len(client.transport._grpc_channel._pool) == expected_pool_size - assert not client._active_instances - assert len(client._channel_refresh_tasks) == expected_pool_size - assert client.transport._credentials == expected_credentials - client.close() - - def test_ctor_super_inits(self): - from google.cloud.client import ClientWithProject - from google.api_core import client_options as client_options_lib - from google.cloud.bigtable import __version__ as bigtable_version - - project = "project-id" - pool_size = 11 - credentials = AnonymousCredentials() - client_options = {"api_endpoint": "foo.bar:1234"} - options_parsed = client_options_lib.from_dict(client_options) - asyncio_portion = "-async" if CrossSync._Sync_Impl.is_async else "" - transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" - with mock.patch.object(BigtableClient, "__init__") as bigtable_client_init: - bigtable_client_init.return_value = None - with mock.patch.object( - ClientWithProject, "__init__" - ) as client_project_init: - client_project_init.return_value = None - try: - self._make_client( - project=project, - pool_size=pool_size, - 
credentials=credentials, - client_options=options_parsed, - use_emulator=False, - ) - except AttributeError: - pass - assert bigtable_client_init.call_count == 1 - kwargs = bigtable_client_init.call_args[1] - assert kwargs["transport"] == transport_str - assert kwargs["credentials"] == credentials - assert kwargs["client_options"] == options_parsed - assert client_project_init.call_count == 1 - kwargs = client_project_init.call_args[1] - assert kwargs["project"] == project - assert kwargs["credentials"] == credentials - assert kwargs["client_options"] == options_parsed - - def test_ctor_dict_options(self): - from google.api_core.client_options import ClientOptions - - client_options = {"api_endpoint": "foo.bar:1234"} - with mock.patch.object(BigtableClient, "__init__") as bigtable_client_init: - try: - self._make_client(client_options=client_options) - except TypeError: - pass - bigtable_client_init.assert_called_once() - kwargs = bigtable_client_init.call_args[1] - called_options = kwargs["client_options"] - assert called_options.api_endpoint == "foo.bar:1234" - assert isinstance(called_options, ClientOptions) - with mock.patch.object( - self._get_target_class(), "_start_background_channel_refresh" - ) as start_background_refresh: - client = self._make_client( - client_options=client_options, use_emulator=False - ) - start_background_refresh.assert_called_once() - client.close() - - def test_veneer_grpc_headers(self): - client_component = "data-async" if CrossSync._Sync_Impl.is_async else "data" - VENEER_HEADER_REGEX = re.compile( - "gapic\\/[0-9]+\\.[\\w.-]+ gax\\/[0-9]+\\.[\\w.-]+ gccl\\/[0-9]+\\.[\\w.-]+-" - + client_component - + " gl-python\\/[0-9]+\\.[\\w.-]+ grpc\\/[0-9]+\\.[\\w.-]+" - ) - if CrossSync._Sync_Impl.is_async: - patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") - else: - patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") - with patch as gapic_mock: - client = self._make_client(project="project-id") - wrapped_call_list = gapic_mock.call_args_list - assert len(wrapped_call_list) > 0 - for call in wrapped_call_list: - client_info = call.kwargs["client_info"] - assert client_info is not None, f"{call} has no client_info" - wrapped_user_agent_sorted = " ".join( - sorted(client_info.to_user_agent().split(" ")) - ) - assert VENEER_HEADER_REGEX.match( - wrapped_user_agent_sorted - ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" - client.close() - - def test_channel_pool_creation(self): - pool_size = 14 - with mock.patch.object( - grpc_helpers, "create_channel", CrossSync._Sync_Impl.Mock() - ) as create_channel: - client = self._make_client(project="project-id", pool_size=pool_size) - assert create_channel.call_count == pool_size - client.close() - client = self._make_client(project="project-id", pool_size=pool_size) - pool_list = list(client.transport._grpc_channel._pool) - pool_set = set(client.transport._grpc_channel._pool) - assert len(pool_list) == len(pool_set) - client.close() - - def test_channel_pool_rotation(self): - pool_size = 7 - with mock.patch.object(PooledChannel, "next_channel") as next_channel: - client = self._make_client(project="project-id", pool_size=pool_size) - assert len(client.transport._grpc_channel._pool) == pool_size - next_channel.reset_mock() - with mock.patch.object( - type(client.transport._grpc_channel._pool[0]), "unary_unary" - ) as unary_unary: - channel_next = None - for i in range(pool_size): - channel_last = channel_next - channel_next = client.transport.grpc_channel._pool[i] - 
assert channel_last != channel_next - next_channel.return_value = channel_next - client.transport.ping_and_warm() - assert next_channel.call_count == i + 1 - unary_unary.assert_called_once() - unary_unary.reset_mock() - client.close() - - def test_channel_pool_replace(self): - import time - - sleep_module = asyncio if CrossSync._Sync_Impl.is_async else time - with mock.patch.object(sleep_module, "sleep"): - pool_size = 7 - client = self._make_client(project="project-id", pool_size=pool_size) - for replace_idx in range(pool_size): - start_pool = [ - channel for channel in client.transport._grpc_channel._pool - ] - grace_period = 9 - with mock.patch.object( - type(client.transport._grpc_channel._pool[-1]), "close" - ) as close: - new_channel = client.transport.create_channel() - client.transport.replace_channel( - replace_idx, grace=grace_period, new_channel=new_channel - ) - close.assert_called_once() - if CrossSync._Sync_Impl.is_async: - close.assert_called_once_with(grace=grace_period) - close.assert_awaited_once() - assert client.transport._grpc_channel._pool[replace_idx] == new_channel - for i in range(pool_size): - if i != replace_idx: - assert client.transport._grpc_channel._pool[i] == start_pool[i] - else: - assert client.transport._grpc_channel._pool[i] != start_pool[i] - client.close() - - def test__start_background_channel_refresh_tasks_exist(self): - client = self._make_client(project="project-id", use_emulator=False) - assert len(client._channel_refresh_tasks) > 0 - with mock.patch.object(asyncio, "create_task") as create_task: - client._start_background_channel_refresh() - create_task.assert_not_called() - client.close() - - @pytest.mark.parametrize("pool_size", [1, 3, 7]) - def test__start_background_channel_refresh(self, pool_size): - import concurrent.futures - - with mock.patch.object( - self._get_target_class(), - "_ping_and_warm_instances", - CrossSync._Sync_Impl.Mock(), - ) as ping_and_warm: - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - client._start_background_channel_refresh() - assert len(client._channel_refresh_tasks) == pool_size - for task in client._channel_refresh_tasks: - if CrossSync._Sync_Impl.is_async: - assert isinstance(task, asyncio.Task) - else: - assert isinstance(task, concurrent.futures.Future) - if CrossSync._Sync_Impl.is_async: - asyncio.sleep(0.1) - assert ping_and_warm.call_count == pool_size - for channel in client.transport._grpc_channel._pool: - ping_and_warm.assert_any_call(channel) - client.close() - - def test__ping_and_warm_instances(self): - """test ping and warm with mocked asyncio.gather""" - client_mock = mock.Mock() - client_mock._execute_ping_and_warms = ( - lambda *args: self._get_target_class()._execute_ping_and_warms( - client_mock, *args - ) - ) - with mock.patch.object( - CrossSync._Sync_Impl, "gather_partials", CrossSync._Sync_Impl.Mock() - ) as gather: - gather.side_effect = lambda partials, **kwargs: [None for _ in partials] - channel = mock.Mock() - client_mock._active_instances = [] - result = self._get_target_class()._ping_and_warm_instances( - client_mock, channel - ) - assert len(result) == 0 - assert gather.call_args.kwargs["return_exceptions"] is True - assert gather.call_args.kwargs["sync_executor"] == client_mock._executor - client_mock._active_instances = [ - (mock.Mock(), mock.Mock(), mock.Mock()) - ] * 4 - gather.reset_mock() - channel.reset_mock() - result = self._get_target_class()._ping_and_warm_instances( - client_mock, channel - ) - assert len(result) == 4 - 
gather.assert_called_once() - partial_list = gather.call_args.args[0] - assert len(partial_list) == 4 - if CrossSync._Sync_Impl.is_async: - gather.assert_awaited_once() - grpc_call_args = channel.unary_unary().call_args_list - for idx, (_, kwargs) in enumerate(grpc_call_args): - ( - expected_instance, - expected_table, - expected_app_profile, - ) = client_mock._active_instances[idx] - request = kwargs["request"] - assert request["name"] == expected_instance - assert request["app_profile_id"] == expected_app_profile - metadata = kwargs["metadata"] - assert len(metadata) == 1 - assert metadata[0][0] == "x-goog-request-params" - assert ( - metadata[0][1] - == f"name={expected_instance}&app_profile_id={expected_app_profile}" - ) - - def test__ping_and_warm_single_instance(self): - """should be able to call ping and warm with single instance""" - client_mock = mock.Mock() - client_mock._execute_ping_and_warms = ( - lambda *args: self._get_target_class()._execute_ping_and_warms( - client_mock, *args - ) - ) - with mock.patch.object( - CrossSync._Sync_Impl, "gather_partials", CrossSync._Sync_Impl.Mock() - ) as gather: - gather.side_effect = lambda *args, **kwargs: [fn() for fn in args[0]] - channel = mock.Mock() - client_mock._active_instances = [mock.Mock()] * 100 - test_key = ("test-instance", "test-table", "test-app-profile") - result = self._get_target_class()._ping_and_warm_instances( - client_mock, channel, test_key - ) - assert len(result) == 1 - grpc_call_args = channel.unary_unary().call_args_list - assert len(grpc_call_args) == 1 - kwargs = grpc_call_args[0][1] - request = kwargs["request"] - assert request["name"] == "test-instance" - assert request["app_profile_id"] == "test-app-profile" - metadata = kwargs["metadata"] - assert len(metadata) == 1 - assert metadata[0][0] == "x-goog-request-params" - assert ( - metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" - ) - - @pytest.mark.parametrize( - "refresh_interval, wait_time, expected_sleep", - [(0, 0, 0), (0, 1, 0), (10, 0, 10), (10, 5, 5), (10, 10, 0), (10, 15, 0)], - ) - def test__manage_channel_first_sleep( - self, refresh_interval, wait_time, expected_sleep - ): - import time - - with mock.patch.object(time, "monotonic") as monotonic: - monotonic.return_value = 0 - with mock.patch.object(CrossSync._Sync_Impl, "event_wait") as sleep: - sleep.side_effect = asyncio.CancelledError - try: - client = self._make_client(project="project-id") - client._channel_init_time = -wait_time - client._manage_channel(0, refresh_interval, refresh_interval) - except asyncio.CancelledError: - pass - sleep.assert_called_once() - call_time = sleep.call_args[0][1] - assert ( - abs(call_time - expected_sleep) < 0.1 - ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" - client.close() - - def test__manage_channel_ping_and_warm(self): - """_manage channel should call ping and warm internally""" - import time - import threading - - client_mock = mock.Mock() - client_mock._is_closed.is_set.return_value = False - client_mock._channel_init_time = time.monotonic() - channel_list = [mock.Mock(), mock.Mock()] - client_mock.transport.channels = channel_list - new_channel = mock.Mock() - client_mock.transport.grpc_channel._create_channel.return_value = new_channel - sleep_tuple = ( - (asyncio, "sleep") - if CrossSync._Sync_Impl.is_async - else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple): - client_mock.transport.replace_channel.side_effect = asyncio.CancelledError - 
ping_and_warm = ( - client_mock._ping_and_warm_instances - ) = CrossSync._Sync_Impl.Mock() - try: - channel_idx = 1 - self._get_target_class()._manage_channel(client_mock, channel_idx, 10) - except asyncio.CancelledError: - pass - assert ping_and_warm.call_count == 2 - assert client_mock.transport.replace_channel.call_count == 1 - old_channel = channel_list[channel_idx] - assert old_channel != new_channel - called_with = [call[0][0] for call in ping_and_warm.call_args_list] - assert old_channel in called_with - assert new_channel in called_with - ping_and_warm.reset_mock() - try: - self._get_target_class()._manage_channel(client_mock, 0, 0, 0) - except asyncio.CancelledError: - pass - ping_and_warm.assert_called_once_with(new_channel) - - @pytest.mark.parametrize( - "refresh_interval, num_cycles, expected_sleep", - [(None, 1, 60 * 35), (10, 10, 100), (10, 1, 10)], - ) - def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sleep): - import time - import random - import threading - - channel_idx = 1 - with mock.patch.object(random, "uniform") as uniform: - uniform.side_effect = lambda min_, max_: min_ - with mock.patch.object(time, "time") as time_mock: - time_mock.return_value = 0 - sleep_tuple = ( - (asyncio, "sleep") - if CrossSync._Sync_Impl.is_async - else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: - sleep.side_effect = [None for i in range(num_cycles - 1)] + [ - asyncio.CancelledError - ] - client = self._make_client(project="project-id") - with mock.patch.object(client.transport, "replace_channel"): - try: - if refresh_interval is not None: - client._manage_channel( - channel_idx, refresh_interval, refresh_interval - ) - else: - client._manage_channel(channel_idx) - except asyncio.CancelledError: - pass - assert sleep.call_count == num_cycles - if CrossSync._Sync_Impl.is_async: - total_sleep = sum([call[0][0] for call in sleep.call_args_list]) - else: - total_sleep = sum( - [call[1]["timeout"] for call in sleep.call_args_list] - ) - assert ( - abs(total_sleep - expected_sleep) < 0.1 - ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" - client.close() - - def test__manage_channel_random(self): - import random - import threading - - sleep_tuple = ( - (asyncio, "sleep") - if CrossSync._Sync_Impl.is_async - else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: - with mock.patch.object(random, "uniform") as uniform: - uniform.return_value = 0 - try: - uniform.side_effect = asyncio.CancelledError - client = self._make_client(project="project-id", pool_size=1) - except asyncio.CancelledError: - uniform.side_effect = None - uniform.reset_mock() - sleep.reset_mock() - min_val = 200 - max_val = 205 - uniform.side_effect = lambda min_, max_: min_ - sleep.side_effect = [None, None, asyncio.CancelledError] - try: - with mock.patch.object(client.transport, "replace_channel"): - client._manage_channel(0, min_val, max_val) - except asyncio.CancelledError: - pass - assert uniform.call_count == 3 - uniform_args = [call[0] for call in uniform.call_args_list] - for found_min, found_max in uniform_args: - assert found_min == min_val - assert found_max == max_val - - @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) - def test__manage_channel_refresh(self, num_cycles): - import threading - - expected_grace = 9 - expected_refresh = 0.5 - channel_idx = 1 - grpc_lib = grpc.aio if CrossSync._Sync_Impl.is_async else grpc - new_channel = 
grpc_lib.insecure_channel("localhost:8080") - with mock.patch.object( - PooledBigtableGrpcTransport, "replace_channel" - ) as replace_channel: - sleep_tuple = ( - (asyncio, "sleep") - if CrossSync._Sync_Impl.is_async - else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: - sleep.side_effect = [None for i in range(num_cycles)] + [ - asyncio.CancelledError - ] - with mock.patch.object( - grpc_helpers, "create_channel" - ) as create_channel: - create_channel.return_value = new_channel - with mock.patch.object( - self._get_target_class(), "_start_background_channel_refresh" - ): - client = self._make_client( - project="project-id", use_emulator=False - ) - create_channel.reset_mock() - try: - client._manage_channel( - channel_idx, - refresh_interval_min=expected_refresh, - refresh_interval_max=expected_refresh, - grace_period=expected_grace, - ) - except asyncio.CancelledError: - pass - assert sleep.call_count == num_cycles + 1 - assert create_channel.call_count == num_cycles - assert replace_channel.call_count == num_cycles - for call in replace_channel.call_args_list: - args, kwargs = call - assert args[0] == channel_idx - assert kwargs["grace"] == expected_grace - assert kwargs["new_channel"] == new_channel - client.close() - - def test__register_instance(self): - """test instance registration""" - client_mock = mock.Mock() - client_mock._gapic_client.instance_path.side_effect = lambda a, b: f"prefix/{b}" - active_instances = set() - instance_owners = {} - client_mock._active_instances = active_instances - client_mock._instance_owners = instance_owners - client_mock._channel_refresh_tasks = [] - client_mock._start_background_channel_refresh.side_effect = ( - lambda: client_mock._channel_refresh_tasks.append(mock.Mock) - ) - mock_channels = [mock.Mock() for i in range(5)] - client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() - table_mock = mock.Mock() - self._get_target_class()._register_instance( - client_mock, "instance-1", table_mock - ) - assert client_mock._start_background_channel_refresh.call_count == 1 - expected_key = ( - "prefix/instance-1", - table_mock.table_name, - table_mock.app_profile_id, - ) - assert len(active_instances) == 1 - assert expected_key == tuple(list(active_instances)[0]) - assert len(instance_owners) == 1 - assert expected_key == tuple(list(instance_owners)[0]) - assert client_mock._channel_refresh_tasks - table_mock2 = mock.Mock() - self._get_target_class()._register_instance( - client_mock, "instance-2", table_mock2 - ) - assert client_mock._start_background_channel_refresh.call_count == 1 - assert client_mock._ping_and_warm_instances.call_count == len(mock_channels) - for channel in mock_channels: - assert channel in [ - call[0][0] - for call in client_mock._ping_and_warm_instances.call_args_list - ] - assert len(active_instances) == 2 - assert len(instance_owners) == 2 - expected_key2 = ( - "prefix/instance-2", - table_mock2.table_name, - table_mock2.app_profile_id, - ) - assert any( - [ - expected_key2 == tuple(list(active_instances)[i]) - for i in range(len(active_instances)) - ] - ) - assert any( - [ - expected_key2 == tuple(list(instance_owners)[i]) - for i in range(len(instance_owners)) - ] - ) - - @pytest.mark.parametrize( - "insert_instances,expected_active,expected_owner_keys", - [ - ([("i", "t", None)], [("i", "t", None)], [("i", "t", None)]), - ([("i", "t", "p")], [("i", "t", "p")], [("i", "t", "p")]), - ([("1", "t", "p"), ("1", "t", "p")], [("1", "t", 
"p")], [("1", "t", "p")]), - ( - [("1", "t", "p"), ("2", "t", "p")], - [("1", "t", "p"), ("2", "t", "p")], - [("1", "t", "p"), ("2", "t", "p")], - ), - ], - ) - def test__register_instance_state( - self, insert_instances, expected_active, expected_owner_keys - ): - """test that active_instances and instance_owners are updated as expected""" - client_mock = mock.Mock() - client_mock._gapic_client.instance_path.side_effect = lambda a, b: b - active_instances = set() - instance_owners = {} - client_mock._active_instances = active_instances - client_mock._instance_owners = instance_owners - client_mock._channel_refresh_tasks = [] - client_mock._start_background_channel_refresh.side_effect = ( - lambda: client_mock._channel_refresh_tasks.append(mock.Mock) - ) - mock_channels = [mock.Mock() for i in range(5)] - client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() - table_mock = mock.Mock() - for instance, table, profile in insert_instances: - table_mock.table_name = table - table_mock.app_profile_id = profile - self._get_target_class()._register_instance( - client_mock, instance, table_mock - ) - assert len(active_instances) == len(expected_active) - assert len(instance_owners) == len(expected_owner_keys) - for expected in expected_active: - assert any( - [ - expected == tuple(list(active_instances)[i]) - for i in range(len(active_instances)) - ] - ) - for expected in expected_owner_keys: - assert any( - [ - expected == tuple(list(instance_owners)[i]) - for i in range(len(instance_owners)) - ] - ) - - def test__remove_instance_registration(self): - client = self._make_client(project="project-id") - table = mock.Mock() - client._register_instance("instance-1", table) - client._register_instance("instance-2", table) - assert len(client._active_instances) == 2 - assert len(client._instance_owners.keys()) == 2 - instance_1_path = client._gapic_client.instance_path( - client.project, "instance-1" - ) - instance_1_key = (instance_1_path, table.table_name, table.app_profile_id) - instance_2_path = client._gapic_client.instance_path( - client.project, "instance-2" - ) - instance_2_key = (instance_2_path, table.table_name, table.app_profile_id) - assert len(client._instance_owners[instance_1_key]) == 1 - assert list(client._instance_owners[instance_1_key])[0] == id(table) - assert len(client._instance_owners[instance_2_key]) == 1 - assert list(client._instance_owners[instance_2_key])[0] == id(table) - success = client._remove_instance_registration("instance-1", table) - assert success - assert len(client._active_instances) == 1 - assert len(client._instance_owners[instance_1_key]) == 0 - assert len(client._instance_owners[instance_2_key]) == 1 - assert client._active_instances == {instance_2_key} - success = client._remove_instance_registration("fake-key", table) - assert not success - assert len(client._active_instances) == 1 - client.close() - - def test__multiple_table_registration(self): - """registering with multiple tables with the same key should - add multiple owners to instance_owners, but only keep one copy - of shared key in active_instances""" - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - - with self._make_client(project="project-id") as client: - with client.get_table("instance_1", "table_1") as table_1: - instance_1_path = client._gapic_client.instance_path( - client.project, "instance_1" - ) - instance_1_key = _WarmedInstanceKey( - instance_1_path, table_1.table_name, table_1.app_profile_id - ) - assert 
len(client._instance_owners[instance_1_key]) == 1 - assert len(client._active_instances) == 1 - assert id(table_1) in client._instance_owners[instance_1_key] - with client.get_table("instance_1", "table_1") as table_2: - assert table_2._register_instance_future is not None - if not CrossSync._Sync_Impl.is_async: - table_2._register_instance_future.result() - assert len(client._instance_owners[instance_1_key]) == 2 - assert len(client._active_instances) == 1 - assert id(table_1) in client._instance_owners[instance_1_key] - assert id(table_2) in client._instance_owners[instance_1_key] - with client.get_table("instance_1", "table_3") as table_3: - assert table_3._register_instance_future is not None - if not CrossSync._Sync_Impl.is_async: - table_3._register_instance_future.result() - instance_3_path = client._gapic_client.instance_path( - client.project, "instance_1" - ) - instance_3_key = _WarmedInstanceKey( - instance_3_path, table_3.table_name, table_3.app_profile_id - ) - assert len(client._instance_owners[instance_1_key]) == 2 - assert len(client._instance_owners[instance_3_key]) == 1 - assert len(client._active_instances) == 2 - assert id(table_1) in client._instance_owners[instance_1_key] - assert id(table_2) in client._instance_owners[instance_1_key] - assert id(table_3) in client._instance_owners[instance_3_key] - assert len(client._active_instances) == 1 - assert instance_1_key in client._active_instances - assert id(table_2) not in client._instance_owners[instance_1_key] - assert len(client._active_instances) == 0 - assert instance_1_key not in client._active_instances - assert len(client._instance_owners[instance_1_key]) == 0 - - def test__multiple_instance_registration(self): - """registering with multiple instance keys should update the key - in instance_owners and active_instances""" - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - - with self._make_client(project="project-id") as client: - with client.get_table("instance_1", "table_1") as table_1: - assert table_1._register_instance_future is not None - if not CrossSync._Sync_Impl.is_async: - table_1._register_instance_future.result() - with client.get_table("instance_2", "table_2") as table_2: - assert table_2._register_instance_future is not None - if not CrossSync._Sync_Impl.is_async: - table_2._register_instance_future.result() - instance_1_path = client._gapic_client.instance_path( - client.project, "instance_1" - ) - instance_1_key = _WarmedInstanceKey( - instance_1_path, table_1.table_name, table_1.app_profile_id - ) - instance_2_path = client._gapic_client.instance_path( - client.project, "instance_2" - ) - instance_2_key = _WarmedInstanceKey( - instance_2_path, table_2.table_name, table_2.app_profile_id - ) - assert len(client._instance_owners[instance_1_key]) == 1 - assert len(client._instance_owners[instance_2_key]) == 1 - assert len(client._active_instances) == 2 - assert id(table_1) in client._instance_owners[instance_1_key] - assert id(table_2) in client._instance_owners[instance_2_key] - assert len(client._active_instances) == 1 - assert instance_1_key in client._active_instances - assert len(client._instance_owners[instance_2_key]) == 0 - assert len(client._instance_owners[instance_1_key]) == 1 - assert id(table_1) in client._instance_owners[instance_1_key] - assert len(client._active_instances) == 0 - assert len(client._instance_owners[instance_1_key]) == 0 - assert len(client._instance_owners[instance_2_key]) == 0 - - def test_get_table(self): - from google.cloud.bigtable.data._helpers 
import _WarmedInstanceKey - - client = self._make_client(project="project-id") - assert not client._active_instances - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - table = client.get_table( - expected_instance_id, expected_table_id, expected_app_profile_id - ) - CrossSync._Sync_Impl.yield_to_event_loop() - assert isinstance(table, TestTable._get_target_class()) - assert table.table_id == expected_table_id - assert ( - table.table_name - == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" - ) - assert table.instance_id == expected_instance_id - assert ( - table.instance_name - == f"projects/{client.project}/instances/{expected_instance_id}" - ) - assert table.app_profile_id == expected_app_profile_id - assert table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id - ) - assert instance_key in client._active_instances - assert client._instance_owners[instance_key] == {id(table)} - client.close() - - def test_get_table_arg_passthrough(self): - """All arguments passed in get_table should be sent to constructor""" - with self._make_client(project="project-id") as client: - with mock.patch.object( - TestTable._get_target_class(), "__init__" - ) as mock_constructor: - mock_constructor.return_value = None - assert not client._active_instances - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - expected_args = (1, "test", {"test": 2}) - expected_kwargs = {"hello": "world", "test": 2} - client.get_table( - expected_instance_id, - expected_table_id, - expected_app_profile_id, - *expected_args, - **expected_kwargs, - ) - mock_constructor.assert_called_once_with( - client, - expected_instance_id, - expected_table_id, - expected_app_profile_id, - *expected_args, - **expected_kwargs, - ) - - def test_get_table_context_manager(self): - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - expected_project_id = "project-id" - with mock.patch.object(TestTable._get_target_class(), "close") as close_mock: - with self._make_client(project=expected_project_id) as client: - with client.get_table( - expected_instance_id, expected_table_id, expected_app_profile_id - ) as table: - CrossSync._Sync_Impl.yield_to_event_loop() - assert isinstance(table, TestTable._get_target_class()) - assert table.table_id == expected_table_id - assert ( - table.table_name - == f"projects/{expected_project_id}/instances/{expected_instance_id}/tables/{expected_table_id}" - ) - assert table.instance_id == expected_instance_id - assert ( - table.instance_name - == f"projects/{expected_project_id}/instances/{expected_instance_id}" - ) - assert table.app_profile_id == expected_app_profile_id - assert table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id - ) - assert instance_key in client._active_instances - assert client._instance_owners[instance_key] == {id(table)} - assert close_mock.call_count == 1 - - def test_multiple_pool_sizes(self): - pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] - for pool_size in pool_sizes: - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - assert len(client._channel_refresh_tasks) == pool_size - client_duplicate = 
self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - assert len(client_duplicate._channel_refresh_tasks) == pool_size - assert str(pool_size) in str(client.transport) - client.close() - client_duplicate.close() - - def test_close(self): - pool_size = 7 - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - assert len(client._channel_refresh_tasks) == pool_size - tasks_list = list(client._channel_refresh_tasks) - for task in client._channel_refresh_tasks: - assert not task.done() - with mock.patch.object( - PooledBigtableGrpcTransport, "close", CrossSync._Sync_Impl.Mock() - ) as close_mock: - client.close() - close_mock.assert_called_once() - if CrossSync._Sync_Impl.is_async: - close_mock.assert_awaited() - for task in tasks_list: - assert task.done() - - def test_close_with_timeout(self): - pool_size = 7 - expected_timeout = 19 - client = self._make_client(project="project-id", pool_size=pool_size) - tasks = list(client._channel_refresh_tasks) - with mock.patch.object( - CrossSync._Sync_Impl, "wait", CrossSync._Sync_Impl.Mock() - ) as wait_for_mock: - client.close(timeout=expected_timeout) - wait_for_mock.assert_called_once() - if CrossSync._Sync_Impl.is_async: - wait_for_mock.assert_awaited() - assert wait_for_mock.call_args[1]["timeout"] == expected_timeout - client._channel_refresh_tasks = tasks - client.close() - - def test_context_manager(self): - close_mock = CrossSync._Sync_Impl.Mock() - true_close = None - with self._make_client(project="project-id") as client: - true_close = client.close() - client.close = close_mock - for task in client._channel_refresh_tasks: - assert not task.done() - assert client.project == "project-id" - assert client._active_instances == set() - close_mock.assert_not_called() - close_mock.assert_called_once() - if CrossSync._Sync_Impl.is_async: - close_mock.assert_awaited() - true_close - - -class TestTable: - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - @staticmethod - def _get_target_class(): - return Table - - def test_table_ctor(self): - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - - expected_table_id = "table-id" - expected_instance_id = "instance-id" - expected_app_profile_id = "app-profile-id" - expected_operation_timeout = 123 - expected_attempt_timeout = 12 - expected_read_rows_operation_timeout = 1.5 - expected_read_rows_attempt_timeout = 0.5 - expected_mutate_rows_operation_timeout = 2.5 - expected_mutate_rows_attempt_timeout = 0.75 - client = self._make_client() - assert not client._active_instances - table = self._get_target_class()( - client, - expected_instance_id, - expected_table_id, - expected_app_profile_id, - default_operation_timeout=expected_operation_timeout, - default_attempt_timeout=expected_attempt_timeout, - default_read_rows_operation_timeout=expected_read_rows_operation_timeout, - default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout, - default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, - default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, - ) - CrossSync._Sync_Impl.yield_to_event_loop() - assert table.table_id == expected_table_id - assert table.instance_id == expected_instance_id - assert table.app_profile_id == expected_app_profile_id - assert table.client is client - instance_key = _WarmedInstanceKey( - table.instance_name, table.table_name, table.app_profile_id - ) - assert instance_key in 
client._active_instances - assert client._instance_owners[instance_key] == {id(table)} - assert table.default_operation_timeout == expected_operation_timeout - assert table.default_attempt_timeout == expected_attempt_timeout - assert ( - table.default_read_rows_operation_timeout - == expected_read_rows_operation_timeout - ) - assert ( - table.default_read_rows_attempt_timeout - == expected_read_rows_attempt_timeout - ) - assert ( - table.default_mutate_rows_operation_timeout - == expected_mutate_rows_operation_timeout - ) - assert ( - table.default_mutate_rows_attempt_timeout - == expected_mutate_rows_attempt_timeout - ) - table._register_instance_future - assert table._register_instance_future.done() - assert not table._register_instance_future.cancelled() - assert table._register_instance_future.exception() is None - client.close() - - def test_table_ctor_defaults(self): - """should provide default timeout values and app_profile_id""" - expected_table_id = "table-id" - expected_instance_id = "instance-id" - client = self._make_client() - assert not client._active_instances - table = self._get_target_class()( - client, expected_instance_id, expected_table_id - ) - CrossSync._Sync_Impl.yield_to_event_loop() - assert table.table_id == expected_table_id - assert table.instance_id == expected_instance_id - assert table.app_profile_id is None - assert table.client is client - assert table.default_operation_timeout == 60 - assert table.default_read_rows_operation_timeout == 600 - assert table.default_mutate_rows_operation_timeout == 600 - assert table.default_attempt_timeout == 20 - assert table.default_read_rows_attempt_timeout == 20 - assert table.default_mutate_rows_attempt_timeout == 60 - client.close() - - def test_table_ctor_invalid_timeout_values(self): - """bad timeout values should raise ValueError""" - client = self._make_client() - timeout_pairs = [ - ("default_operation_timeout", "default_attempt_timeout"), - ( - "default_read_rows_operation_timeout", - "default_read_rows_attempt_timeout", - ), - ( - "default_mutate_rows_operation_timeout", - "default_mutate_rows_attempt_timeout", - ), - ] - for operation_timeout, attempt_timeout in timeout_pairs: - with pytest.raises(ValueError) as e: - self._get_target_class()(client, "", "", **{attempt_timeout: -1}) - assert "attempt_timeout must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - self._get_target_class()(client, "", "", **{operation_timeout: -1}) - assert "operation_timeout must be greater than 0" in str(e.value) - client.close() - - @pytest.mark.parametrize( - "fn_name,fn_args,is_stream,extra_retryables", - [ - ("read_rows_stream", (ReadRowsQuery(),), True, ()), - ("read_rows", (ReadRowsQuery(),), True, ()), - ("read_row", (b"row_key",), True, ()), - ("read_rows_sharded", ([ReadRowsQuery()],), True, ()), - ("row_exists", (b"row_key",), True, ()), - ("sample_row_keys", (), False, ()), - ("mutate_row", (b"row_key", [mock.Mock()]), False, ()), - ( - "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), - False, - (_MutateRowsIncomplete,), - ), - ], - ) - @pytest.mark.parametrize( - "input_retryables,expected_retryables", - [ - ( - TABLE_DEFAULT.READ_ROWS, - [ - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - core_exceptions.Aborted, - ], - ), - ( - TABLE_DEFAULT.DEFAULT, - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ), - ( - TABLE_DEFAULT.MUTATE_ROWS, - [core_exceptions.DeadlineExceeded, 
core_exceptions.ServiceUnavailable], - ), - ([], []), - ([4], [core_exceptions.DeadlineExceeded]), - ], - ) - def test_customizable_retryable_errors( - self, - input_retryables, - expected_retryables, - fn_name, - fn_args, - is_stream, - extra_retryables, - ): - """Test that retryable functions support user-configurable arguments, and that the configured retryables are passed - down to the gapic layer.""" - retry_fn = "retry_target" - if is_stream: - retry_fn += "_stream" - if CrossSync._Sync_Impl.is_async: - retry_fn = f"CrossSync.{retry_fn}" - else: - retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" - with mock.patch( - f"google.cloud.bigtable.data._sync.cross_sync.{retry_fn}" - ) as retry_fn_mock: - with self._make_client() as client: - table = client.get_table("instance-id", "table-id") - expected_predicate = expected_retryables.__contains__ - retry_fn_mock.side_effect = RuntimeError("stop early") - with mock.patch( - "google.api_core.retry.if_exception_type" - ) as predicate_builder_mock: - predicate_builder_mock.return_value = expected_predicate - with pytest.raises(Exception): - test_fn = table.__getattribute__(fn_name) - test_fn(*fn_args, retryable_errors=input_retryables) - predicate_builder_mock.assert_called_once_with( - *expected_retryables, *extra_retryables - ) - retry_call_args = retry_fn_mock.call_args_list[0].args - assert retry_call_args[1] is expected_predicate - - @pytest.mark.parametrize( - "fn_name,fn_args,gapic_fn", - [ - ("read_rows_stream", (ReadRowsQuery(),), "read_rows"), - ("read_rows", (ReadRowsQuery(),), "read_rows"), - ("read_row", (b"row_key",), "read_rows"), - ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), - ("row_exists", (b"row_key",), "read_rows"), - ("sample_row_keys", (), "sample_row_keys"), - ("mutate_row", (b"row_key", [mock.Mock()]), "mutate_row"), - ( - "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), - "mutate_rows", - ), - ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), - ( - "read_modify_write_row", - (b"row_key", mock.Mock()), - "read_modify_write_row", - ), - ], - ) - @pytest.mark.parametrize("include_app_profile", [True, False]) - def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): - """check that all requests attach proper metadata headers""" - profile = "profile" if include_app_profile else None - with mock.patch.object( - BigtableClient, gapic_fn, CrossSync._Sync_Impl.Mock() - ) as gapic_mock: - gapic_mock.side_effect = RuntimeError("stop early") - with self._make_client() as client: - table = self._get_target_class()( - client, "instance-id", "table-id", profile - ) - try: - test_fn = table.__getattribute__(fn_name) - maybe_stream = test_fn(*fn_args) - [i for i in maybe_stream] - except Exception: - pass - kwargs = gapic_mock.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata - - -class TestReadRows: - """ - Tests for table.read_rows and related methods. 
- """ - - @staticmethod - def _get_operation_class(): - return _ReadRowsOperation - - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - def _make_table(self, *args, **kwargs): - client_mock = mock.Mock() - client_mock._register_instance.side_effect = ( - lambda *args, **kwargs: CrossSync._Sync_Impl.yield_to_event_loop() - ) - client_mock._remove_instance_registration.side_effect = ( - lambda *args, **kwargs: CrossSync._Sync_Impl.yield_to_event_loop() - ) - kwargs["instance_id"] = kwargs.get( - "instance_id", args[0] if args else "instance" - ) - kwargs["table_id"] = kwargs.get( - "table_id", args[1] if len(args) > 1 else "table" - ) - client_mock._gapic_client.table_path.return_value = kwargs["table_id"] - client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] - return TestTable._get_target_class()(client_mock, *args, **kwargs) - - def _make_stats(self): - from google.cloud.bigtable_v2.types import RequestStats - from google.cloud.bigtable_v2.types import FullReadStatsView - from google.cloud.bigtable_v2.types import ReadIterationStats - - return RequestStats( - full_read_stats_view=FullReadStatsView( - read_iteration_stats=ReadIterationStats( - rows_seen_count=1, - rows_returned_count=2, - cells_seen_count=3, - cells_returned_count=4, - ) - ) - ) - - @staticmethod - def _make_chunk(*args, **kwargs): - from google.cloud.bigtable_v2 import ReadRowsResponse - - kwargs["row_key"] = kwargs.get("row_key", b"row_key") - kwargs["family_name"] = kwargs.get("family_name", "family_name") - kwargs["qualifier"] = kwargs.get("qualifier", b"qualifier") - kwargs["value"] = kwargs.get("value", b"value") - kwargs["commit_row"] = kwargs.get("commit_row", True) - return ReadRowsResponse.CellChunk(*args, **kwargs) - - @staticmethod - def _make_gapic_stream( - chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0 - ): - from google.cloud.bigtable_v2 import ReadRowsResponse - - class mock_stream: - def __init__(self, chunk_list, sleep_time): - self.chunk_list = chunk_list - self.idx = -1 - self.sleep_time = sleep_time - - def __aiter__(self): - return self - - def __iter__(self): - return self - - def __anext__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - if sleep_time: - CrossSync._Sync_Impl.sleep(self.sleep_time) - chunk = self.chunk_list[self.idx] - if isinstance(chunk, Exception): - raise chunk - else: - return ReadRowsResponse(chunks=[chunk]) - raise CrossSync._Sync_Impl.StopIteration - - def __next__(self): - return self.__anext__() - - def cancel(self): - pass - - return mock_stream(chunk_list, sleep_time) - - def execute_fn(self, table, *args, **kwargs): - return table.read_rows(*args, **kwargs) - - def test_read_rows(self): - query = ReadRowsQuery() - chunks = [ - self._make_chunk(row_key=b"test_1"), - self._make_chunk(row_key=b"test_2"), - ] - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks - ) - results = self.execute_fn(table, query, operation_timeout=3) - assert len(results) == 2 - assert results[0].row_key == b"test_1" - assert results[1].row_key == b"test_2" - - def test_read_rows_stream(self): - query = ReadRowsQuery() - chunks = [ - self._make_chunk(row_key=b"test_1"), - self._make_chunk(row_key=b"test_2"), - ] - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: 
self._make_gapic_stream( - chunks - ) - gen = table.read_rows_stream(query, operation_timeout=3) - results = [row for row in gen] - assert len(results) == 2 - assert results[0].row_key == b"test_1" - assert results[1].row_key == b"test_2" - - @pytest.mark.parametrize("include_app_profile", [True, False]) - def test_read_rows_query_matches_request(self, include_app_profile): - from google.cloud.bigtable.data import RowRange - from google.cloud.bigtable.data.row_filters import PassAllFilter - - app_profile_id = "app_profile_id" if include_app_profile else None - with self._make_table(app_profile_id=app_profile_id) as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream([]) - row_keys = [b"test_1", "test_2"] - row_ranges = RowRange("1start", "2end") - filter_ = PassAllFilter(True) - limit = 99 - query = ReadRowsQuery( - row_keys=row_keys, - row_ranges=row_ranges, - row_filter=filter_, - limit=limit, - ) - results = table.read_rows(query, operation_timeout=3) - assert len(results) == 0 - call_request = read_rows.call_args_list[0][0][0] - query_pb = query._to_pb(table) - assert call_request == query_pb - - @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) - def test_read_rows_timeout(self, operation_timeout): - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - query = ReadRowsQuery() - chunks = [self._make_chunk(row_key=b"test_1")] - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks, sleep_time=0.15 - ) - try: - table.read_rows(query, operation_timeout=operation_timeout) - except core_exceptions.DeadlineExceeded as e: - assert ( - e.message - == f"operation_timeout of {operation_timeout:0.1f}s exceeded" - ) - - @pytest.mark.parametrize( - "per_request_t, operation_t, expected_num", - [(0.05, 0.08, 2), (0.05, 0.14, 3), (0.05, 0.24, 5)], - ) - def test_read_rows_attempt_timeout(self, per_request_t, operation_t, expected_num): - """Ensures that the attempt_timeout is respected and that the number of - requests is as expected. 
- - operation_timeout does not cancel the request, so we expect the number of - requests to be the ceiling of operation_timeout / attempt_timeout.""" - from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - - expected_last_timeout = operation_t - (expected_num - 1) * per_request_t - with mock.patch("random.uniform", side_effect=lambda a, b: 0): - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks, sleep_time=per_request_t - ) - query = ReadRowsQuery() - chunks = [core_exceptions.DeadlineExceeded("mock deadline")] - try: - table.read_rows( - query, - operation_timeout=operation_t, - attempt_timeout=per_request_t, - ) - except core_exceptions.DeadlineExceeded as e: - retry_exc = e.__cause__ - if expected_num == 0: - assert retry_exc is None - else: - assert type(retry_exc) is RetryExceptionGroup - assert f"{expected_num} failed attempts" in str(retry_exc) - assert len(retry_exc.exceptions) == expected_num - for sub_exc in retry_exc.exceptions: - assert sub_exc.message == "mock deadline" - assert read_rows.call_count == expected_num - for _, call_kwargs in read_rows.call_args_list[:-1]: - assert call_kwargs["timeout"] == per_request_t - assert call_kwargs["retry"] is None - assert ( - abs( - read_rows.call_args_list[-1][1]["timeout"] - - expected_last_timeout - ) - < 0.05 - ) - - @pytest.mark.parametrize( - "exc_type", - [ - core_exceptions.Aborted, - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - ], - ) - def test_read_rows_retryable_error(self, exc_type): - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - [expected_error] - ) - query = ReadRowsQuery() - expected_error = exc_type("mock error") - try: - table.read_rows(query, operation_timeout=0.1) - except core_exceptions.DeadlineExceeded as e: - retry_exc = e.__cause__ - root_cause = retry_exc.exceptions[0] - assert type(root_cause) is exc_type - assert root_cause == expected_error - - @pytest.mark.parametrize( - "exc_type", - [ - core_exceptions.Cancelled, - core_exceptions.PreconditionFailed, - core_exceptions.NotFound, - core_exceptions.PermissionDenied, - core_exceptions.Conflict, - core_exceptions.InternalServerError, - core_exceptions.TooManyRequests, - core_exceptions.ResourceExhausted, - InvalidChunk, - ], - ) - def test_read_rows_non_retryable_error(self, exc_type): - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - [expected_error] - ) - query = ReadRowsQuery() - expected_error = exc_type("mock error") - try: - table.read_rows(query, operation_timeout=0.1) - except exc_type as e: - assert e == expected_error - - def test_read_rows_revise_request(self): - """Ensure that _revise_request is called between retries""" - from google.cloud.bigtable.data.exceptions import InvalidChunk - from google.cloud.bigtable_v2.types import RowSet - - return_val = RowSet() - with mock.patch.object( - self._get_operation_class(), "_revise_request_rowset" - ) as revise_rowset: - revise_rowset.return_value = return_val - with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks - ) - row_keys = [b"test_1", b"test_2", b"test_3"] - query = 
ReadRowsQuery(row_keys=row_keys) - chunks = [ - self._make_chunk(row_key=b"test_1"), - core_exceptions.Aborted("mock retryable error"), - ] - try: - table.read_rows(query) - except InvalidChunk: - revise_rowset.assert_called() - first_call_kwargs = revise_rowset.call_args_list[0].kwargs - assert first_call_kwargs["row_set"] == query._to_pb(table).rows - assert first_call_kwargs["last_seen_row_key"] == b"test_1" - revised_call = read_rows.call_args_list[1].args[0] - assert revised_call.rows == return_val - - def test_read_rows_default_timeouts(self): - """Ensure that the default timeouts are set on the read rows operation when not overridden""" - operation_timeout = 8 - attempt_timeout = 4 - with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: - mock_op.side_effect = RuntimeError("mock error") - with self._make_table( - default_read_rows_operation_timeout=operation_timeout, - default_read_rows_attempt_timeout=attempt_timeout, - ) as table: - try: - table.read_rows(ReadRowsQuery()) - except RuntimeError: - pass - kwargs = mock_op.call_args_list[0].kwargs - assert kwargs["operation_timeout"] == operation_timeout - assert kwargs["attempt_timeout"] == attempt_timeout - - def test_read_rows_default_timeout_override(self): - """When timeouts are passed, they overwrite default values""" - operation_timeout = 8 - attempt_timeout = 4 - with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: - mock_op.side_effect = RuntimeError("mock error") - with self._make_table( - default_operation_timeout=99, default_attempt_timeout=97 - ) as table: - try: - table.read_rows( - ReadRowsQuery(), - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - ) - except RuntimeError: - pass - kwargs = mock_op.call_args_list[0].kwargs - assert kwargs["operation_timeout"] == operation_timeout - assert kwargs["attempt_timeout"] == attempt_timeout - - def test_read_row(self): - """Test reading a single row""" - with self._make_client() as client: - table = client.get_table("instance", "table") - row_key = b"test_1" - with mock.patch.object(table, "read_rows") as read_rows: - expected_result = object() - read_rows.side_effect = lambda *args, **kwargs: [expected_result] - expected_op_timeout = 8 - expected_req_timeout = 4 - row = table.read_row( - row_key, - operation_timeout=expected_op_timeout, - attempt_timeout=expected_req_timeout, - ) - assert row == expected_result - assert read_rows.call_count == 1 - args, kwargs = read_rows.call_args_list[0] - assert kwargs["operation_timeout"] == expected_op_timeout - assert kwargs["attempt_timeout"] == expected_req_timeout - assert len(args) == 1 - assert isinstance(args[0], ReadRowsQuery) - query = args[0] - assert query.row_keys == [row_key] - assert query.row_ranges == [] - assert query.limit == 1 - - def test_read_row_w_filter(self): - """Test reading a single row with an added filter""" - with self._make_client() as client: - table = client.get_table("instance", "table") - row_key = b"test_1" - with mock.patch.object(table, "read_rows") as read_rows: - expected_result = object() - read_rows.side_effect = lambda *args, **kwargs: [expected_result] - expected_op_timeout = 8 - expected_req_timeout = 4 - mock_filter = mock.Mock() - expected_filter = {"filter": "mock filter"} - mock_filter._to_dict.return_value = expected_filter - row = table.read_row( - row_key, - operation_timeout=expected_op_timeout, - attempt_timeout=expected_req_timeout, - row_filter=expected_filter, - ) - assert row == expected_result - assert 
read_rows.call_count == 1 - args, kwargs = read_rows.call_args_list[0] - assert kwargs["operation_timeout"] == expected_op_timeout - assert kwargs["attempt_timeout"] == expected_req_timeout - assert len(args) == 1 - assert isinstance(args[0], ReadRowsQuery) - query = args[0] - assert query.row_keys == [row_key] - assert query.row_ranges == [] - assert query.limit == 1 - assert query.filter == expected_filter - - def test_read_row_no_response(self): - """should return None if row does not exist""" - with self._make_client() as client: - table = client.get_table("instance", "table") - row_key = b"test_1" - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = lambda *args, **kwargs: [] - expected_op_timeout = 8 - expected_req_timeout = 4 - result = table.read_row( - row_key, - operation_timeout=expected_op_timeout, - attempt_timeout=expected_req_timeout, - ) - assert result is None - assert read_rows.call_count == 1 - args, kwargs = read_rows.call_args_list[0] - assert kwargs["operation_timeout"] == expected_op_timeout - assert kwargs["attempt_timeout"] == expected_req_timeout - assert isinstance(args[0], ReadRowsQuery) - query = args[0] - assert query.row_keys == [row_key] - assert query.row_ranges == [] - assert query.limit == 1 - - @pytest.mark.parametrize( - "return_value,expected_result", - [([], False), ([object()], True), ([object(), object()], True)], - ) - def test_row_exists(self, return_value, expected_result): - """Test checking for row existence""" - with self._make_client() as client: - table = client.get_table("instance", "table") - row_key = b"test_1" - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = lambda *args, **kwargs: return_value - expected_op_timeout = 1 - expected_req_timeout = 2 - result = table.row_exists( - row_key, - operation_timeout=expected_op_timeout, - attempt_timeout=expected_req_timeout, - ) - assert expected_result == result - assert read_rows.call_count == 1 - args, kwargs = read_rows.call_args_list[0] - assert kwargs["operation_timeout"] == expected_op_timeout - assert kwargs["attempt_timeout"] == expected_req_timeout - assert isinstance(args[0], ReadRowsQuery) - expected_filter = { - "chain": { - "filters": [ - {"cells_per_row_limit_filter": 1}, - {"strip_value_transformer": True}, - ] - } - } - query = args[0] - assert query.row_keys == [row_key] - assert query.row_ranges == [] - assert query.limit == 1 - assert query.filter._to_dict() == expected_filter - - -class TestReadRowsSharded: - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - def test_read_rows_sharded_empty_query(self): - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as exc: - table.read_rows_sharded([]) - assert "empty sharded_query" in str(exc.value) - - def test_read_rows_sharded_multiple_queries(self): - """Test with multiple queries. 
Should return results from both""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - table.client._gapic_client, "read_rows" - ) as read_rows: - read_rows.side_effect = ( - lambda *args, **kwargs: TestReadRows._make_gapic_stream( - [ - TestReadRows._make_chunk(row_key=k) - for k in args[0].rows.row_keys - ] - ) - ) - query_1 = ReadRowsQuery(b"test_1") - query_2 = ReadRowsQuery(b"test_2") - result = table.read_rows_sharded([query_1, query_2]) - assert len(result) == 2 - assert result[0].row_key == b"test_1" - assert result[1].row_key == b"test_2" - - @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24]) - def test_read_rows_sharded_multiple_queries_calls(self, n_queries): - """Each query should trigger a separate read_rows call""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - query_list = [ReadRowsQuery() for _ in range(n_queries)] - table.read_rows_sharded(query_list) - assert read_rows.call_count == n_queries - - def test_read_rows_sharded_errors(self): - """Errors should be exposed as ShardedReadRowsExceptionGroups""" - from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup - from google.cloud.bigtable.data.exceptions import FailedQueryShardError - - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = RuntimeError("mock error") - query_1 = ReadRowsQuery(b"test_1") - query_2 = ReadRowsQuery(b"test_2") - with pytest.raises(ShardedReadRowsExceptionGroup) as exc: - table.read_rows_sharded([query_1, query_2]) - exc_group = exc.value - assert isinstance(exc_group, ShardedReadRowsExceptionGroup) - assert len(exc.value.exceptions) == 2 - assert isinstance(exc.value.exceptions[0], FailedQueryShardError) - assert isinstance(exc.value.exceptions[0].__cause__, RuntimeError) - assert exc.value.exceptions[0].index == 0 - assert exc.value.exceptions[0].query == query_1 - assert isinstance(exc.value.exceptions[1], FailedQueryShardError) - assert isinstance(exc.value.exceptions[1].__cause__, RuntimeError) - assert exc.value.exceptions[1].index == 1 - assert exc.value.exceptions[1].query == query_2 - - def test_read_rows_sharded_concurrent(self): - """Ensure sharded requests are concurrent""" - import time - - def mock_call(*args, **kwargs): - asyncio.sleep(0.1) - return [mock.Mock()] - - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = mock_call - queries = [ReadRowsQuery() for _ in range(10)] - start_time = time.monotonic() - result = table.read_rows_sharded(queries) - call_time = time.monotonic() - start_time - assert read_rows.call_count == 10 - assert len(result) == 10 - assert call_time < 0.2 - - def test_read_rows_sharded_concurrency_limit(self): - """Only 10 queries should be processed concurrently. 
Others should be queued - - Should start a new query as soon as previous finishes""" - from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT - - assert _CONCURRENCY_LIMIT == 10 - num_queries = 15 - increment_time = 0.05 - max_time = increment_time * (_CONCURRENCY_LIMIT - 1) - rpc_times = [min(i * increment_time, max_time) for i in range(num_queries)] - - def mock_call(*args, **kwargs): - next_sleep = rpc_times.pop(0) - asyncio.sleep(next_sleep) - return [mock.Mock()] - - starting_timeout = 10 - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = mock_call - queries = [ReadRowsQuery() for _ in range(num_queries)] - table.read_rows_sharded(queries, operation_timeout=starting_timeout) - assert read_rows.call_count == num_queries - rpc_start_list = [ - starting_timeout - kwargs["operation_timeout"] - for _, kwargs in read_rows.call_args_list - ] - eps = 0.01 - assert all( - (rpc_start_list[i] < eps for i in range(_CONCURRENCY_LIMIT)) - ) - for i in range(num_queries - _CONCURRENCY_LIMIT): - idx = i + _CONCURRENCY_LIMIT - assert rpc_start_list[idx] - i * increment_time < eps - - def test_read_rows_sharded_expirary(self): - """If the operation times out before all shards complete, should raise - a ShardedReadRowsExceptionGroup""" - from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT - from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup - from google.api_core.exceptions import DeadlineExceeded - - operation_timeout = 0.1 - num_queries = 15 - sleeps = [0] * _CONCURRENCY_LIMIT + [DeadlineExceeded("times up")] * ( - num_queries - _CONCURRENCY_LIMIT - ) - - def mock_call(*args, **kwargs): - next_item = sleeps.pop(0) - if isinstance(next_item, Exception): - raise next_item - else: - asyncio.sleep(next_item) - return [mock.Mock()] - - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = mock_call - queries = [ReadRowsQuery() for _ in range(num_queries)] - with pytest.raises(ShardedReadRowsExceptionGroup) as exc: - table.read_rows_sharded( - queries, operation_timeout=operation_timeout - ) - assert isinstance(exc.value, ShardedReadRowsExceptionGroup) - assert len(exc.value.exceptions) == num_queries - _CONCURRENCY_LIMIT - assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT - - def test_read_rows_sharded_negative_batch_timeout(self): - """try to run with batch that starts after operation timeout - - They should raise DeadlineExceeded errors""" - from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup - from google.api_core.exceptions import DeadlineExceeded - - def mock_call(*args, **kwargs): - CrossSync._Sync_Impl.sleep(0.05) - return [mock.Mock()] - - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object(table, "read_rows") as read_rows: - read_rows.side_effect = mock_call - queries = [ReadRowsQuery() for _ in range(15)] - with pytest.raises(ShardedReadRowsExceptionGroup) as exc: - table.read_rows_sharded(queries, operation_timeout=0.01) - assert isinstance(exc.value, ShardedReadRowsExceptionGroup) - assert len(exc.value.exceptions) == 5 - assert all( - ( - isinstance(e.__cause__, DeadlineExceeded) - for e in exc.value.exceptions - ) - ) - - -class TestSampleRowKeys: - def _make_client(self, *args, **kwargs): - 
return TestBigtableDataClient._make_client(*args, **kwargs) - - def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): - from google.cloud.bigtable_v2.types import SampleRowKeysResponse - - for value in sample_list: - yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) - - def test_sample_row_keys(self): - """Test that method returns the expected key samples""" - samples = [(b"test_1", 0), (b"test_2", 100), (b"test_3", 200)] - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - table.client._gapic_client, - "sample_row_keys", - CrossSync._Sync_Impl.Mock(), - ) as sample_row_keys: - sample_row_keys.return_value = self._make_gapic_stream(samples) - result = table.sample_row_keys() - assert len(result) == 3 - assert all((isinstance(r, tuple) for r in result)) - assert all((isinstance(r[0], bytes) for r in result)) - assert all((isinstance(r[1], int) for r in result)) - assert result[0] == samples[0] - assert result[1] == samples[1] - assert result[2] == samples[2] - - def test_sample_row_keys_bad_timeout(self): - """should raise error if timeout is negative""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as e: - table.sample_row_keys(operation_timeout=-1) - assert "operation_timeout must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - table.sample_row_keys(attempt_timeout=-1) - assert "attempt_timeout must be greater than 0" in str(e.value) - - def test_sample_row_keys_default_timeout(self): - """Should fallback to using table default operation_timeout""" - expected_timeout = 99 - with self._make_client() as client: - with client.get_table( - "i", - "t", - default_operation_timeout=expected_timeout, - default_attempt_timeout=expected_timeout, - ) as table: - with mock.patch.object( - table.client._gapic_client, - "sample_row_keys", - CrossSync._Sync_Impl.Mock(), - ) as sample_row_keys: - sample_row_keys.return_value = self._make_gapic_stream([]) - result = table.sample_row_keys() - _, kwargs = sample_row_keys.call_args - assert abs(kwargs["timeout"] - expected_timeout) < 0.1 - assert result == [] - assert kwargs["retry"] is None - - def test_sample_row_keys_gapic_params(self): - """make sure arguments are propagated to gapic call as expected""" - expected_timeout = 10 - expected_profile = "test1" - instance = "instance_name" - table_id = "my_table" - with self._make_client() as client: - with client.get_table( - instance, table_id, app_profile_id=expected_profile - ) as table: - with mock.patch.object( - table.client._gapic_client, - "sample_row_keys", - CrossSync._Sync_Impl.Mock(), - ) as sample_row_keys: - sample_row_keys.return_value = self._make_gapic_stream([]) - table.sample_row_keys(attempt_timeout=expected_timeout) - args, kwargs = sample_row_keys.call_args - assert len(args) == 0 - assert len(kwargs) == 5 - assert kwargs["timeout"] == expected_timeout - assert kwargs["app_profile_id"] == expected_profile - assert kwargs["table_name"] == table.table_name - assert kwargs["metadata"] is not None - assert kwargs["retry"] is None - - @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_sample_row_keys_retryable_errors(self, retryable_exception): - """retryable errors should be retried until timeout""" - from google.api_core.exceptions import DeadlineExceeded - from 
google.cloud.bigtable.data.exceptions import RetryExceptionGroup - - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - table.client._gapic_client, - "sample_row_keys", - CrossSync._Sync_Impl.Mock(), - ) as sample_row_keys: - sample_row_keys.side_effect = retryable_exception("mock") - with pytest.raises(DeadlineExceeded) as e: - table.sample_row_keys(operation_timeout=0.05) - cause = e.value.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert len(cause.exceptions) > 0 - assert isinstance(cause.exceptions[0], retryable_exception) - - @pytest.mark.parametrize( - "non_retryable_exception", - [ - core_exceptions.OutOfRange, - core_exceptions.NotFound, - core_exceptions.FailedPrecondition, - RuntimeError, - ValueError, - core_exceptions.Aborted, - ], - ) - def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): - """non-retryable errors should cause a raise""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - table.client._gapic_client, - "sample_row_keys", - CrossSync._Sync_Impl.Mock(), - ) as sample_row_keys: - sample_row_keys.side_effect = non_retryable_exception("mock") - with pytest.raises(non_retryable_exception): - table.sample_row_keys() - - -class TestMutateRow: - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - @pytest.mark.parametrize( - "mutation_arg", - [ - mutations.SetCell("family", b"qualifier", b"value"), - mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=1234567890 - ), - mutations.DeleteRangeFromColumn("family", b"qualifier"), - mutations.DeleteAllFromFamily("family"), - mutations.DeleteAllFromRow(), - [mutations.SetCell("family", b"qualifier", b"value")], - [ - mutations.DeleteRangeFromColumn("family", b"qualifier"), - mutations.DeleteAllFromRow(), - ], - ], - ) - def test_mutate_row(self, mutation_arg): - """Test mutations with no errors""" - expected_attempt_timeout = 19 - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.return_value = None - table.mutate_row( - "row_key", - mutation_arg, - attempt_timeout=expected_attempt_timeout, - ) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args_list[0].kwargs - assert ( - kwargs["table_name"] - == "projects/project/instances/instance/tables/table" - ) - assert kwargs["row_key"] == b"row_key" - formatted_mutations = ( - [mutation._to_pb() for mutation in mutation_arg] - if isinstance(mutation_arg, list) - else [mutation_arg._to_pb()] - ) - assert kwargs["mutations"] == formatted_mutations - assert kwargs["timeout"] == expected_attempt_timeout - assert kwargs["retry"] is None - - @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_mutate_row_retryable_errors(self, retryable_exception): - from google.api_core.exceptions import DeadlineExceeded - from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(DeadlineExceeded) as e: - mutation = 
mutations.DeleteAllFromRow() - assert mutation.is_idempotent() is True - table.mutate_row("row_key", mutation, operation_timeout=0.01) - cause = e.value.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert isinstance(cause.exceptions[0], retryable_exception) - - @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_mutate_row_non_idempotent_retryable_errors(self, retryable_exception): - """Non-idempotent mutations should not be retried""" - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(retryable_exception): - mutation = mutations.SetCell( - "family", b"qualifier", b"value", -1 - ) - assert mutation.is_idempotent() is False - table.mutate_row("row_key", mutation, operation_timeout=0.2) - - @pytest.mark.parametrize( - "non_retryable_exception", - [ - core_exceptions.OutOfRange, - core_exceptions.NotFound, - core_exceptions.FailedPrecondition, - RuntimeError, - ValueError, - core_exceptions.Aborted, - ], - ) - def test_mutate_row_non_retryable_errors(self, non_retryable_exception): - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = non_retryable_exception("mock") - with pytest.raises(non_retryable_exception): - mutation = mutations.SetCell( - "family", - b"qualifier", - b"value", - timestamp_micros=1234567890, - ) - assert mutation.is_idempotent() is True - table.mutate_row("row_key", mutation, operation_timeout=0.2) - - @pytest.mark.parametrize("include_app_profile", [True, False]) - def test_mutate_row_metadata(self, include_app_profile): - """request should attach metadata headers""" - profile = "profile" if include_app_profile else None - with self._make_client() as client: - with client.get_table("i", "t", app_profile_id=profile) as table: - with mock.patch.object( - client._gapic_client, "mutate_row", CrossSync._Sync_Impl.Mock() - ) as read_rows: - table.mutate_row("rk", mock.Mock()) - kwargs = read_rows.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata - - @pytest.mark.parametrize("mutations", [[], None]) - def test_mutate_row_no_mutations(self, mutations): - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as e: - table.mutate_row("key", mutations=mutations) - assert e.value.args[0] == "No mutations provided" - - -class TestBulkMutateRows: - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - def _mock_response(self, response_list): - from google.cloud.bigtable_v2.types import MutateRowsResponse - from google.rpc import status_pb2 - - statuses = [] - for response in response_list: - if isinstance(response, core_exceptions.GoogleAPICallError): - statuses.append( - status_pb2.Status( - 
message=str(response), code=response.grpc_status_code.value[0] - ) - ) - else: - statuses.append(status_pb2.Status(code=0)) - entries = [ - MutateRowsResponse.Entry(index=i, status=statuses[i]) - for i in range(len(response_list)) - ] - - def generator(): - yield MutateRowsResponse(entries=entries) - - return generator() - - @pytest.mark.parametrize( - "mutation_arg", - [ - [mutations.SetCell("family", b"qualifier", b"value")], - [ - mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=1234567890 - ) - ], - [mutations.DeleteRangeFromColumn("family", b"qualifier")], - [mutations.DeleteAllFromFamily("family")], - [mutations.DeleteAllFromRow()], - [mutations.SetCell("family", b"qualifier", b"value")], - [ - mutations.DeleteRangeFromColumn("family", b"qualifier"), - mutations.DeleteAllFromRow(), - ], - ], - ) - def test_bulk_mutate_rows(self, mutation_arg): - """Test mutations with no errors""" - expected_attempt_timeout = 19 - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.return_value = self._mock_response([None]) - bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg) - table.bulk_mutate_rows( - [bulk_mutation], attempt_timeout=expected_attempt_timeout - ) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args[1] - assert ( - kwargs["table_name"] - == "projects/project/instances/instance/tables/table" - ) - assert kwargs["entries"] == [bulk_mutation._to_pb()] - assert kwargs["timeout"] == expected_attempt_timeout - assert kwargs["retry"] is None - - def test_bulk_mutate_rows_multiple_entries(self): - """Test mutations with no errors""" - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.return_value = self._mock_response([None, None]) - mutation_list = [mutations.DeleteAllFromRow()] - entry_1 = mutations.RowMutationEntry(b"row_key_1", mutation_list) - entry_2 = mutations.RowMutationEntry(b"row_key_2", mutation_list) - table.bulk_mutate_rows([entry_1, entry_2]) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args[1] - assert ( - kwargs["table_name"] - == "projects/project/instances/instance/tables/table" - ) - assert kwargs["entries"][0] == entry_1._to_pb() - assert kwargs["entries"][1] == entry_2._to_pb() - - @pytest.mark.parametrize( - "exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_bulk_mutate_rows_idempotent_mutation_error_retryable(self, exception): - """Individual idempotent mutations should be retried if they fail with a retryable error""" - from google.cloud.bigtable.data.exceptions import ( - RetryExceptionGroup, - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = lambda *a, **k: self._mock_response( - [exception("mock")] - ) - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.DeleteAllFromRow() - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - table.bulk_mutate_rows([entry], operation_timeout=0.05) - assert len(e.value.exceptions) == 1 - failed_exception = 
e.value.exceptions[0] - assert "non-idempotent" not in str(failed_exception) - assert isinstance(failed_exception, FailedMutationEntryError) - cause = failed_exception.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert isinstance(cause.exceptions[0], exception) - assert isinstance( - cause.exceptions[-1], core_exceptions.DeadlineExceeded - ) - - @pytest.mark.parametrize( - "exception", - [ - core_exceptions.OutOfRange, - core_exceptions.NotFound, - core_exceptions.FailedPrecondition, - core_exceptions.Aborted, - ], - ) - def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable(self, exception): - """Individual idempotent mutations should not be retried if they fail with a non-retryable error""" - from google.cloud.bigtable.data.exceptions import ( - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = lambda *a, **k: self._mock_response( - [exception("mock")] - ) - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.DeleteAllFromRow() - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - table.bulk_mutate_rows([entry], operation_timeout=0.05) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert "non-idempotent" not in str(failed_exception) - assert isinstance(failed_exception, FailedMutationEntryError) - cause = failed_exception.__cause__ - assert isinstance(cause, exception) - - @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_bulk_mutate_idempotent_retryable_request_errors(self, retryable_exception): - """Individual idempotent mutations should be retried if the request fails with a retryable error""" - from google.cloud.bigtable.data.exceptions import ( - RetryExceptionGroup, - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - table.bulk_mutate_rows([entry], operation_timeout=0.05) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" not in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert isinstance(cause.exceptions[0], retryable_exception) - - @pytest.mark.parametrize( - "retryable_exception", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ) - def test_bulk_mutate_rows_non_idempotent_retryable_errors( - self, retryable_exception - ): - """Non-Idempotent mutations should never be retried""" - from google.cloud.bigtable.data.exceptions import ( - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - 
client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = lambda *a, **k: self._mock_response( - [retryable_exception("mock")] - ) - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", -1 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is False - table.bulk_mutate_rows([entry], operation_timeout=0.2) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, retryable_exception) - - @pytest.mark.parametrize( - "non_retryable_exception", - [ - core_exceptions.OutOfRange, - core_exceptions.NotFound, - core_exceptions.FailedPrecondition, - RuntimeError, - ValueError, - ], - ) - def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception): - """If the request fails with a non-retryable error, mutations should not be retried""" - from google.cloud.bigtable.data.exceptions import ( - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = non_retryable_exception("mock") - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - table.bulk_mutate_rows([entry], operation_timeout=0.2) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" not in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, non_retryable_exception) - - def test_bulk_mutate_error_index(self): - """Test partial failure, partial success. 
Errors should be associated with the correct index""" - from google.api_core.exceptions import ( - DeadlineExceeded, - ServiceUnavailable, - FailedPrecondition, - ) - from google.cloud.bigtable.data.exceptions import ( - RetryExceptionGroup, - FailedMutationEntryError, - MutationsExceptionGroup, - ) - - with self._make_client(project="project") as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = [ - self._mock_response([None, ServiceUnavailable("mock"), None]), - self._mock_response([DeadlineExceeded("mock")]), - self._mock_response([FailedPrecondition("final")]), - ] - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entries = [ - mutations.RowMutationEntry( - f"row_key_{i}".encode(), [mutation] - ) - for i in range(3) - ] - assert mutation.is_idempotent() is True - table.bulk_mutate_rows(entries, operation_timeout=1000) - assert len(e.value.exceptions) == 1 - failed = e.value.exceptions[0] - assert isinstance(failed, FailedMutationEntryError) - assert failed.index == 1 - assert failed.entry == entries[1] - cause = failed.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert len(cause.exceptions) == 3 - assert isinstance(cause.exceptions[0], ServiceUnavailable) - assert isinstance(cause.exceptions[1], DeadlineExceeded) - assert isinstance(cause.exceptions[2], FailedPrecondition) - - def test_bulk_mutate_error_recovery(self): - """If an error occurs, then resolves, no exception should be raised""" - from google.api_core.exceptions import DeadlineExceeded - - with self._make_client(project="project") as client: - table = client.get_table("instance", "table") - with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: - mock_gapic.side_effect = [ - self._mock_response([DeadlineExceeded("mock")]), - self._mock_response([None]), - ] - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entries = [ - mutations.RowMutationEntry(f"row_key_{i}".encode(), [mutation]) - for i in range(3) - ] - table.bulk_mutate_rows(entries, operation_timeout=1000) - - -class TestCheckAndMutateRow: - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - @pytest.mark.parametrize("gapic_result", [True, False]) - def test_check_and_mutate(self, gapic_result): - from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - - app_profile = "app_profile_id" - with self._make_client() as client: - with client.get_table( - "instance", "table", app_profile_id=app_profile - ) as table: - with mock.patch.object( - client._gapic_client, "check_and_mutate_row" - ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=gapic_result - ) - row_key = b"row_key" - predicate = None - true_mutations = [mock.Mock()] - false_mutations = [mock.Mock(), mock.Mock()] - operation_timeout = 0.2 - found = table.check_and_mutate_row( - row_key, - predicate, - true_case_mutations=true_mutations, - false_case_mutations=false_mutations, - operation_timeout=operation_timeout, - ) - assert found == gapic_result - kwargs = mock_gapic.call_args[1] - assert kwargs["table_name"] == table.table_name - assert kwargs["row_key"] == row_key - assert kwargs["predicate_filter"] == predicate - assert kwargs["true_mutations"] == [ - m._to_pb() for m in true_mutations - ] - assert 
kwargs["false_mutations"] == [ - m._to_pb() for m in false_mutations - ] - assert kwargs["app_profile_id"] == app_profile - assert kwargs["timeout"] == operation_timeout - assert kwargs["retry"] is None - - def test_check_and_mutate_bad_timeout(self): - """Should raise error if operation_timeout < 0""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as e: - table.check_and_mutate_row( - b"row_key", - None, - true_case_mutations=[mock.Mock()], - false_case_mutations=[], - operation_timeout=-1, - ) - assert str(e.value) == "operation_timeout must be greater than 0" - - def test_check_and_mutate_single_mutations(self): - """if single mutations are passed, they should be internally wrapped in a list""" - from google.cloud.bigtable.data.mutations import SetCell - from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "check_and_mutate_row" - ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True - ) - true_mutation = SetCell("family", b"qualifier", b"value") - false_mutation = SetCell("family", b"qualifier", b"value") - table.check_and_mutate_row( - b"row_key", - None, - true_case_mutations=true_mutation, - false_case_mutations=false_mutation, - ) - kwargs = mock_gapic.call_args[1] - assert kwargs["true_mutations"] == [true_mutation._to_pb()] - assert kwargs["false_mutations"] == [false_mutation._to_pb()] - - def test_check_and_mutate_predicate_object(self): - """predicate filter should be passed to gapic request""" - from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - - mock_predicate = mock.Mock() - predicate_pb = {"predicate": "dict"} - mock_predicate._to_pb.return_value = predicate_pb - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "check_and_mutate_row" - ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True - ) - table.check_and_mutate_row( - b"row_key", mock_predicate, false_case_mutations=[mock.Mock()] - ) - kwargs = mock_gapic.call_args[1] - assert kwargs["predicate_filter"] == predicate_pb - assert mock_predicate._to_pb.call_count == 1 - assert kwargs["retry"] is None - - def test_check_and_mutate_mutations_parsing(self): - """mutations objects should be converted to protos""" - from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - from google.cloud.bigtable.data.mutations import DeleteAllFromRow - - mutations = [mock.Mock() for _ in range(5)] - for idx, mutation in enumerate(mutations): - mutation._to_pb.return_value = f"fake {idx}" - mutations.append(DeleteAllFromRow()) - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "check_and_mutate_row" - ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True - ) - table.check_and_mutate_row( - b"row_key", - None, - true_case_mutations=mutations[0:2], - false_case_mutations=mutations[2:], - ) - kwargs = mock_gapic.call_args[1] - assert kwargs["true_mutations"] == ["fake 0", "fake 1"] - assert kwargs["false_mutations"] == [ - "fake 2", - "fake 3", - "fake 4", - DeleteAllFromRow()._to_pb(), - ] - assert all( - (mutation._to_pb.call_count == 1 for mutation in 
mutations[:5]) - ) - - -class TestReadModifyWriteRow: - def _make_client(self, *args, **kwargs): - return TestBigtableDataClient._make_client(*args, **kwargs) - - @pytest.mark.parametrize( - "call_rules,expected_rules", - [ - ( - AppendValueRule("f", "c", b"1"), - [AppendValueRule("f", "c", b"1")._to_pb()], - ), - ( - [AppendValueRule("f", "c", b"1")], - [AppendValueRule("f", "c", b"1")._to_pb()], - ), - (IncrementRule("f", "c", 1), [IncrementRule("f", "c", 1)._to_pb()]), - ( - [AppendValueRule("f", "c", b"1"), IncrementRule("f", "c", 1)], - [ - AppendValueRule("f", "c", b"1")._to_pb(), - IncrementRule("f", "c", 1)._to_pb(), - ], - ), - ], - ) - def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): - """Test that the gapic call is called with given rules""" - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - table.read_modify_write_row("key", call_rules) - assert mock_gapic.call_count == 1 - found_kwargs = mock_gapic.call_args_list[0][1] - assert found_kwargs["rules"] == expected_rules - assert found_kwargs["retry"] is None - - @pytest.mark.parametrize("rules", [[], None]) - def test_read_modify_write_no_rules(self, rules): - with self._make_client() as client: - with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as e: - table.read_modify_write_row("key", rules=rules) - assert e.value.args[0] == "rules must contain at least one item" - - def test_read_modify_write_call_defaults(self): - instance = "instance1" - table_id = "table1" - project = "project1" - row_key = "row_key1" - with self._make_client(project=project) as client: - with client.get_table(instance, table_id) as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - table.read_modify_write_row(row_key, mock.Mock()) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args_list[0][1] - assert ( - kwargs["table_name"] - == f"projects/{project}/instances/{instance}/tables/{table_id}" - ) - assert kwargs["app_profile_id"] is None - assert kwargs["row_key"] == row_key.encode() - assert kwargs["timeout"] > 1 - - def test_read_modify_write_call_overrides(self): - row_key = b"row_key1" - expected_timeout = 12345 - profile_id = "profile1" - with self._make_client() as client: - with client.get_table( - "instance", "table_id", app_profile_id=profile_id - ) as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - table.read_modify_write_row( - row_key, mock.Mock(), operation_timeout=expected_timeout - ) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args_list[0][1] - assert kwargs["app_profile_id"] is profile_id - assert kwargs["row_key"] == row_key - assert kwargs["timeout"] == expected_timeout - - def test_read_modify_write_string_key(self): - row_key = "string_row_key1" - with self._make_client() as client: - with client.get_table("instance", "table_id") as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - table.read_modify_write_row(row_key, mock.Mock()) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args_list[0][1] - assert kwargs["row_key"] == row_key.encode() - - def test_read_modify_write_row_building(self): - """results from gapic call should be used to construct row""" - from google.cloud.bigtable.data.row import Row - from 
google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse - from google.cloud.bigtable_v2.types import Row as RowPB - - mock_response = ReadModifyWriteRowResponse(row=RowPB()) - with self._make_client() as client: - with client.get_table("instance", "table_id") as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - with mock.patch.object(Row, "_from_pb") as constructor_mock: - mock_gapic.return_value = mock_response - table.read_modify_write_row("key", mock.Mock()) - assert constructor_mock.call_count == 1 - constructor_mock.assert_called_once_with(mock_response.row) diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py deleted file mode 100644 index f044e09aa..000000000 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ /dev/null @@ -1,1104 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is automatically generated by CrossSync. Do not edit manually. -import pytest -import asyncio -import time -import google.api_core.exceptions as core_exceptions -import google.api_core.retry -from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete -from google.cloud.bigtable.data import TABLE_DEFAULT -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -if CrossSync._Sync_Impl.is_async: - pass -else: - from google.cloud.bigtable.data._sync.client import Table - from google.cloud.bigtable.data._sync.mutations_batcher import ( - _FlowControl, - MutationsBatcher, - ) -try: - from unittest import mock -except ImportError: - import mock - - -class Test_FlowControl: - @staticmethod - def _target_class(): - return _FlowControl - - def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): - return self._target_class()(max_mutation_count, max_mutation_bytes) - - @staticmethod - def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation - - def test_ctor(self): - max_mutation_count = 9 - max_mutation_bytes = 19 - instance = self._make_one(max_mutation_count, max_mutation_bytes) - assert instance._max_mutation_count == max_mutation_count - assert instance._max_mutation_bytes == max_mutation_bytes - assert instance._in_flight_mutation_count == 0 - assert instance._in_flight_mutation_bytes == 0 - assert isinstance(instance._capacity_condition, CrossSync._Sync_Impl.Condition) - - def test_ctor_invalid_values(self): - """Test that values are positive, and fit within expected limits""" - with pytest.raises(ValueError) as e: - self._make_one(0, 1) - assert "max_mutation_count must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - self._make_one(1, 0) - assert "max_mutation_bytes must be greater than 0" in str(e.value) - - @pytest.mark.parametrize( - "max_count,max_size,existing_count,existing_size,new_count,new_size,expected", - [ - (1, 1, 0, 0, 0, 0, True), - (1, 1, 1, 1, 1, 1, False), - 
(10, 10, 0, 0, 0, 0, True), - (10, 10, 0, 0, 9, 9, True), - (10, 10, 0, 0, 11, 9, True), - (10, 10, 0, 1, 11, 9, True), - (10, 10, 1, 0, 11, 9, False), - (10, 10, 0, 0, 9, 11, True), - (10, 10, 1, 0, 9, 11, True), - (10, 10, 0, 1, 9, 11, False), - (10, 1, 0, 0, 1, 0, True), - (1, 10, 0, 0, 0, 8, True), - (float("inf"), float("inf"), 0, 0, 10000000000.0, 10000000000.0, True), - (8, 8, 0, 0, 10000000000.0, 10000000000.0, True), - (12, 12, 6, 6, 5, 5, True), - (12, 12, 5, 5, 6, 6, True), - (12, 12, 6, 6, 6, 6, True), - (12, 12, 6, 6, 7, 7, False), - (12, 12, 0, 0, 13, 13, True), - (12, 12, 12, 0, 0, 13, True), - (12, 12, 0, 12, 13, 0, True), - (12, 12, 1, 1, 13, 13, False), - (12, 12, 1, 1, 0, 13, False), - (12, 12, 1, 1, 13, 0, False), - ], - ) - def test__has_capacity( - self, - max_count, - max_size, - existing_count, - existing_size, - new_count, - new_size, - expected, - ): - """_has_capacity should return True if the new mutation will not exceed the max count or size""" - instance = self._make_one(max_count, max_size) - instance._in_flight_mutation_count = existing_count - instance._in_flight_mutation_bytes = existing_size - assert instance._has_capacity(new_count, new_size) == expected - - @pytest.mark.parametrize( - "existing_count,existing_size,added_count,added_size,new_count,new_size", - [ - (0, 0, 0, 0, 0, 0), - (2, 2, 1, 1, 1, 1), - (2, 0, 1, 0, 1, 0), - (0, 2, 0, 1, 0, 1), - (10, 10, 0, 0, 10, 10), - (10, 10, 5, 5, 5, 5), - (0, 0, 1, 1, -1, -1), - ], - ) - def test_remove_from_flow_value_update( - self, - existing_count, - existing_size, - added_count, - added_size, - new_count, - new_size, - ): - """completed mutations should lower the in-flight values""" - instance = self._make_one() - instance._in_flight_mutation_count = existing_count - instance._in_flight_mutation_bytes = existing_size - mutation = self._make_mutation(added_count, added_size) - instance.remove_from_flow(mutation) - assert instance._in_flight_mutation_count == new_count - assert instance._in_flight_mutation_bytes == new_size - - def test__remove_from_flow_unlock(self): - """capacity condition should notify after mutation is complete""" - instance = self._make_one(10, 10) - instance._in_flight_mutation_count = 10 - instance._in_flight_mutation_bytes = 10 - - def task_routine(): - with instance._capacity_condition: - instance._capacity_condition.wait_for( - lambda: instance._has_capacity(1, 1) - ) - - if CrossSync._Sync_Impl.is_async: - task = asyncio.create_task(task_routine()) - - def task_alive(): - return not task.done() - - else: - import threading - - thread = threading.Thread(target=task_routine) - thread.start() - task_alive = thread.is_alive - CrossSync._Sync_Impl.sleep(0.05) - assert task_alive() is True - mutation = self._make_mutation(count=0, size=5) - instance.remove_from_flow([mutation]) - CrossSync._Sync_Impl.sleep(0.05) - assert instance._in_flight_mutation_count == 10 - assert instance._in_flight_mutation_bytes == 5 - assert task_alive() is True - instance._in_flight_mutation_bytes = 10 - mutation = self._make_mutation(count=5, size=0) - instance.remove_from_flow([mutation]) - CrossSync._Sync_Impl.sleep(0.05) - assert instance._in_flight_mutation_count == 5 - assert instance._in_flight_mutation_bytes == 10 - assert task_alive() is True - instance._in_flight_mutation_count = 10 - mutation = self._make_mutation(count=5, size=5) - instance.remove_from_flow([mutation]) - CrossSync._Sync_Impl.sleep(0.05) - assert instance._in_flight_mutation_count == 5 - assert 
instance._in_flight_mutation_bytes == 5 - assert task_alive() is False - - @pytest.mark.parametrize( - "mutations,count_cap,size_cap,expected_results", - [ - ([(5, 5), (1, 1), (1, 1)], 10, 10, [[(5, 5), (1, 1), (1, 1)]]), - ([(1, 1), (1, 1), (1, 1)], 1, 1, [[(1, 1)], [(1, 1)], [(1, 1)]]), - ([(1, 1), (1, 1), (1, 1)], 2, 10, [[(1, 1), (1, 1)], [(1, 1)]]), - ([(1, 1), (1, 1), (1, 1)], 10, 2, [[(1, 1), (1, 1)], [(1, 1)]]), - ( - [(1, 1), (5, 5), (4, 1), (1, 4), (1, 1)], - 5, - 5, - [[(1, 1)], [(5, 5)], [(4, 1), (1, 4)], [(1, 1)]], - ), - ], - ) - def test_add_to_flow(self, mutations, count_cap, size_cap, expected_results): - """Test batching with various flow control settings""" - mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] - instance = self._make_one(count_cap, size_cap) - i = 0 - for batch in instance.add_to_flow(mutation_objs): - expected_batch = expected_results[i] - assert len(batch) == len(expected_batch) - for j in range(len(expected_batch)): - assert len(batch[j].mutations) == expected_batch[j][0] - assert batch[j].size() == expected_batch[j][1] - instance.remove_from_flow(batch) - i += 1 - assert i == len(expected_results) - - @pytest.mark.parametrize( - "mutations,max_limit,expected_results", - [ - ([(1, 1)] * 11, 10, [[(1, 1)] * 10, [(1, 1)]]), - ([(1, 1)] * 10, 1, [[(1, 1)] for _ in range(10)]), - ([(1, 1)] * 10, 2, [[(1, 1), (1, 1)] for _ in range(5)]), - ], - ) - def test_add_to_flow_max_mutation_limits( - self, mutations, max_limit, expected_results - ): - """Test flow control running up against the max API limit - Should submit request early, even if the flow control has room for more""" - async_patch = mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", - max_limit, - ) - sync_patch = mock.patch( - "google.cloud.bigtable.data._sync.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", - max_limit, - ) - with async_patch, sync_patch: - mutation_objs = [ - self._make_mutation(count=m[0], size=m[1]) for m in mutations - ] - instance = self._make_one(float("inf"), float("inf")) - i = 0 - for batch in instance.add_to_flow(mutation_objs): - expected_batch = expected_results[i] - assert len(batch) == len(expected_batch) - for j in range(len(expected_batch)): - assert len(batch[j].mutations) == expected_batch[j][0] - assert batch[j].size() == expected_batch[j][1] - instance.remove_from_flow(batch) - i += 1 - assert i == len(expected_results) - - def test_add_to_flow_oversize(self): - """mutations over the flow control limits should still be accepted""" - instance = self._make_one(2, 3) - large_size_mutation = self._make_mutation(count=1, size=10) - large_count_mutation = self._make_mutation(count=10, size=1) - results = [out for out in instance.add_to_flow([large_size_mutation])] - assert len(results) == 1 - instance.remove_from_flow(results[0]) - count_results = [out for out in instance.add_to_flow(large_count_mutation)] - assert len(count_results) == 1 - - -class TestMutationsBatcher: - def _get_target_class(self): - return MutationsBatcher - - def _make_one(self, table=None, **kwargs): - from google.api_core.exceptions import DeadlineExceeded - from google.api_core.exceptions import ServiceUnavailable - - if table is None: - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 10 - table.default_mutate_rows_retryable_errors = ( - DeadlineExceeded, - ServiceUnavailable, - ) - return self._get_target_class()(table, **kwargs) - - 
@staticmethod - def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation - - def test_ctor_defaults(self): - with mock.patch.object( - self._get_target_class(), - "_timer_routine", - return_value=CrossSync._Sync_Impl.Future(), - ) as flush_timer_mock: - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = [Exception] - with self._make_one(table) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._max_mutation_count == 100000 - assert instance._flow_control._max_mutation_bytes == 104857600 - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - assert ( - instance._operation_timeout - == table.default_mutate_rows_operation_timeout - ) - assert ( - instance._attempt_timeout - == table.default_mutate_rows_attempt_timeout - ) - assert ( - instance._retryable_errors - == table.default_mutate_rows_retryable_errors - ) - CrossSync._Sync_Impl.yield_to_event_loop() - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] == 5 - assert isinstance(instance._flush_timer, CrossSync._Sync_Impl.Future) - - def test_ctor_explicit(self): - """Test with explicit parameters""" - with mock.patch.object( - self._get_target_class(), - "_timer_routine", - return_value=CrossSync._Sync_Impl.Future(), - ) as flush_timer_mock: - table = mock.Mock() - flush_interval = 20 - flush_limit_count = 17 - flush_limit_bytes = 19 - flow_control_max_mutation_count = 1001 - flow_control_max_bytes = 12 - operation_timeout = 11 - attempt_timeout = 2 - retryable_errors = [Exception] - with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - flow_control_max_mutation_count=flow_control_max_mutation_count, - flow_control_max_bytes=flow_control_max_bytes, - batch_operation_timeout=operation_timeout, - batch_attempt_timeout=attempt_timeout, - batch_retryable_errors=retryable_errors, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert ( - instance._flow_control._max_mutation_count - == flow_control_max_mutation_count - ) - assert ( - instance._flow_control._max_mutation_bytes == flow_control_max_bytes - ) - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - assert instance._operation_timeout == operation_timeout - assert instance._attempt_timeout == attempt_timeout - assert instance._retryable_errors == retryable_errors - CrossSync._Sync_Impl.yield_to_event_loop() - assert flush_timer_mock.call_count == 
1 - assert flush_timer_mock.call_args[0][0] == flush_interval - assert isinstance(instance._flush_timer, CrossSync._Sync_Impl.Future) - - def test_ctor_no_flush_limits(self): - """Test with None for flush limits""" - with mock.patch.object( - self._get_target_class(), - "_timer_routine", - return_value=CrossSync._Sync_Impl.Future(), - ) as flush_timer_mock: - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = () - flush_interval = None - flush_limit_count = None - flush_limit_bytes = None - with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._staged_entries == [] - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - CrossSync._Sync_Impl.yield_to_event_loop() - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] is None - assert isinstance(instance._flush_timer, CrossSync._Sync_Impl.Future) - - def test_ctor_invalid_values(self): - """Test that timeout values are positive, and fit within expected limits""" - with pytest.raises(ValueError) as e: - self._make_one(batch_operation_timeout=-1) - assert "operation_timeout must be greater than 0" in str(e.value) - with pytest.raises(ValueError) as e: - self._make_one(batch_attempt_timeout=-1) - assert "attempt_timeout must be greater than 0" in str(e.value) - - def test_default_argument_consistency(self): - """We supply default arguments in MutationsBatcherAsync.__init__, and in - table.mutations_batcher. 
Make sure any changes to defaults are applied to - both places""" - import inspect - - get_batcher_signature = dict( - inspect.signature(Table.mutations_batcher).parameters - ) - get_batcher_signature.pop("self") - batcher_init_signature = dict( - inspect.signature(self._get_target_class()).parameters - ) - batcher_init_signature.pop("table") - assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) - assert len(get_batcher_signature) == 8 - assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) - for arg_name in get_batcher_signature.keys(): - assert ( - get_batcher_signature[arg_name].default - == batcher_init_signature[arg_name].default - ) - - @pytest.mark.parametrize("input_val", [None, 0, -1]) - def test__start_flush_timer_w_empty_input(self, input_val): - """Empty/invalid timer should return immediately""" - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: - with self._make_one() as instance: - if CrossSync._Sync_Impl.is_async: - sleep_obj, sleep_method = (asyncio, "wait_for") - else: - sleep_obj, sleep_method = (instance._closed, "wait") - with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: - result = instance._timer_routine(input_val) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 - assert result is None - - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test__start_flush_timer_call_when_closed(self): - """closed batcher's timer should return immediately""" - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: - with self._make_one() as instance: - instance.close() - flush_mock.reset_mock() - if CrossSync._Sync_Impl.is_async: - sleep_obj, sleep_method = (asyncio, "wait_for") - else: - sleep_obj, sleep_method = (instance._closed, "wait") - with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: - instance._timer_routine(10) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 - - @pytest.mark.parametrize("num_staged", [0, 1, 10]) - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test__flush_timer(self, num_staged): - """Timer should continue to call _schedule_flush in a loop""" - from google.cloud.bigtable.data._sync.cross_sync import CrossSync - - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: - expected_sleep = 12 - with self._make_one(flush_interval=expected_sleep) as instance: - loop_num = 3 - instance._staged_entries = [mock.Mock()] * num_staged - with mock.patch.object( - CrossSync._Sync_Impl, "event_wait" - ) as sleep_mock: - sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] - with pytest.raises(TabError): - self._get_target_class()._timer_routine( - instance, expected_sleep - ) - if CrossSync._Sync_Impl.is_async: - instance._flush_timer = CrossSync._Sync_Impl.Future() - assert sleep_mock.call_count == loop_num + 1 - sleep_kwargs = sleep_mock.call_args[1] - assert sleep_kwargs["timeout"] == expected_sleep - assert flush_mock.call_count == (0 if num_staged == 0 else loop_num) - - def test__flush_timer_close(self): - """Timer should continue terminate after close""" - with mock.patch.object(self._get_target_class(), "_schedule_flush"): - with self._make_one() as instance: - assert instance._flush_timer.done() is False - instance.close() - assert instance._flush_timer.done() is True - - def test_append_closed(self): - """Should raise exception""" - instance = self._make_one() - instance.close() - with 
pytest.raises(RuntimeError): - instance.append(mock.Mock()) - - def test_append_wrong_mutation(self): - """Mutation objects should raise an exception. - Only support RowMutationEntry""" - from google.cloud.bigtable.data.mutations import DeleteAllFromRow - - with self._make_one() as instance: - expected_error = "invalid mutation type: DeleteAllFromRow. Only RowMutationEntry objects are supported by batcher" - with pytest.raises(ValueError) as e: - instance.append(DeleteAllFromRow()) - assert str(e.value) == expected_error - - def test_append_outside_flow_limits(self): - """entries larger than mutation limits are still processed""" - with self._make_one( - flow_control_max_mutation_count=1, flow_control_max_bytes=1 - ) as instance: - oversized_entry = self._make_mutation(count=0, size=2) - instance.append(oversized_entry) - assert instance._staged_entries == [oversized_entry] - assert instance._staged_count == 0 - assert instance._staged_bytes == 2 - instance._staged_entries = [] - with self._make_one( - flow_control_max_mutation_count=1, flow_control_max_bytes=1 - ) as instance: - overcount_entry = self._make_mutation(count=2, size=0) - instance.append(overcount_entry) - assert instance._staged_entries == [overcount_entry] - assert instance._staged_count == 2 - assert instance._staged_bytes == 0 - instance._staged_entries = [] - - def test_append_flush_runs_after_limit_hit(self): - """If the user appends a bunch of entries above the flush limits back-to-back, - it should still flush in a single task""" - with mock.patch.object( - self._get_target_class(), "_execute_mutate_rows" - ) as op_mock: - with self._make_one(flush_limit_bytes=100) as instance: - - def mock_call(*args, **kwargs): - return [] - - op_mock.side_effect = mock_call - instance.append(self._make_mutation(size=99)) - num_entries = 10 - for _ in range(num_entries): - instance.append(self._make_mutation(size=1)) - instance._wait_for_batch_results(*instance._flush_jobs) - assert op_mock.call_count == 1 - sent_batch = op_mock.call_args[0][0] - assert len(sent_batch) == 2 - assert len(instance._staged_entries) == num_entries - 1 - - @pytest.mark.parametrize( - "flush_count,flush_bytes,mutation_count,mutation_bytes,expect_flush", - [ - (10, 10, 1, 1, False), - (10, 10, 9, 9, False), - (10, 10, 10, 1, True), - (10, 10, 1, 10, True), - (10, 10, 10, 10, True), - (1, 1, 10, 10, True), - (1, 1, 0, 0, False), - ], - ) - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test_append( - self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush - ): - """test appending different mutations, and checking if it causes a flush""" - with self._make_one( - flush_limit_mutation_count=flush_count, flush_limit_bytes=flush_bytes - ) as instance: - assert instance._staged_count == 0 - assert instance._staged_bytes == 0 - assert instance._staged_entries == [] - mutation = self._make_mutation(count=mutation_count, size=mutation_bytes) - with mock.patch.object(instance, "_schedule_flush") as flush_mock: - instance.append(mutation) - assert flush_mock.call_count == bool(expect_flush) - assert instance._staged_count == mutation_count - assert instance._staged_bytes == mutation_bytes - assert instance._staged_entries == [mutation] - instance._staged_entries = [] - - def test_append_multiple_sequentially(self): - """Append multiple mutations""" - with self._make_one( - flush_limit_mutation_count=8, flush_limit_bytes=8 - ) as instance: - assert instance._staged_count == 0 - assert instance._staged_bytes == 0 - assert 
instance._staged_entries == [] - mutation = self._make_mutation(count=2, size=3) - with mock.patch.object(instance, "_schedule_flush") as flush_mock: - instance.append(mutation) - assert flush_mock.call_count == 0 - assert instance._staged_count == 2 - assert instance._staged_bytes == 3 - assert len(instance._staged_entries) == 1 - instance.append(mutation) - assert flush_mock.call_count == 0 - assert instance._staged_count == 4 - assert instance._staged_bytes == 6 - assert len(instance._staged_entries) == 2 - instance.append(mutation) - assert flush_mock.call_count == 1 - assert instance._staged_count == 6 - assert instance._staged_bytes == 9 - assert len(instance._staged_entries) == 3 - instance._staged_entries = [] - - def test_flush_flow_control_concurrent_requests(self): - """requests should happen in parallel if flow control breaks up single flush into batches""" - import time - - num_calls = 10 - fake_mutations = [self._make_mutation(count=1) for _ in range(num_calls)] - with self._make_one(flow_control_max_mutation_count=1) as instance: - with mock.patch.object( - instance, "_execute_mutate_rows", CrossSync._Sync_Impl.Mock() - ) as op_mock: - - def mock_call(*args, **kwargs): - CrossSync._Sync_Impl.sleep(0.1) - return [] - - op_mock.side_effect = mock_call - start_time = time.monotonic() - instance._staged_entries = fake_mutations - instance._schedule_flush() - CrossSync._Sync_Impl.sleep(0.01) - for i in range(num_calls): - instance._flow_control.remove_from_flow( - [self._make_mutation(count=1)] - ) - CrossSync._Sync_Impl.sleep(0.01) - instance._wait_for_batch_results(*instance._flush_jobs) - duration = time.monotonic() - start_time - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert duration < 0.5 - assert op_mock.call_count == num_calls - - def test_schedule_flush_no_mutations(self): - """schedule flush should return None if no staged mutations""" - with self._make_one() as instance: - with mock.patch.object(instance, "_flush_internal") as flush_mock: - for i in range(3): - assert instance._schedule_flush() is None - assert flush_mock.call_count == 0 - - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - def test_schedule_flush_with_mutations(self): - """if new mutations exist, should add a new flush task to _flush_jobs""" - with self._make_one() as instance: - with mock.patch.object(instance, "_flush_internal") as flush_mock: - if not CrossSync._Sync_Impl.is_async: - flush_mock.side_effect = lambda x: time.sleep(0.1) - for i in range(1, 4): - mutation = mock.Mock() - instance._staged_entries = [mutation] - instance._schedule_flush() - assert instance._staged_entries == [] - asyncio.sleep(0) - assert instance._staged_entries == [] - assert instance._staged_count == 0 - assert instance._staged_bytes == 0 - assert flush_mock.call_count == 1 - flush_mock.reset_mock() - - def test__flush_internal(self): - """_flush_internal should: - - await previous flush call - - delegate batching to _flow_control - - call _execute_mutate_rows on each batch - - update self.exceptions and self._entries_processed_since_last_raise""" - num_entries = 10 - with self._make_one() as instance: - with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: - with mock.patch.object( - instance._flow_control, "add_to_flow" - ) as flow_mock: - - def gen(x): - yield x - - flow_mock.side_effect = lambda x: gen(x) - mutations = [self._make_mutation(count=1, size=1)] * num_entries - instance._flush_internal(mutations) - assert 
instance._entries_processed_since_last_raise == num_entries - assert execute_mock.call_count == 1 - assert flow_mock.call_count == 1 - instance._oldest_exceptions.clear() - instance._newest_exceptions.clear() - - def test_flush_clears_job_list(self): - """a job should be added to _flush_jobs when _schedule_flush is called, - and removed when it completes""" - with self._make_one() as instance: - with mock.patch.object( - instance, "_flush_internal", CrossSync._Sync_Impl.Mock() - ) as flush_mock: - if not CrossSync._Sync_Impl.is_async: - flush_mock.side_effect = lambda x: time.sleep(0.1) - mutations = [self._make_mutation(count=1, size=1)] - instance._staged_entries = mutations - assert instance._flush_jobs == set() - new_job = instance._schedule_flush() - assert instance._flush_jobs == {new_job} - if CrossSync._Sync_Impl.is_async: - new_job - else: - new_job.result() - assert instance._flush_jobs == set() - - @pytest.mark.parametrize( - "num_starting,num_new_errors,expected_total_errors", - [ - (0, 0, 0), - (0, 1, 1), - (0, 2, 2), - (1, 0, 1), - (1, 1, 2), - (10, 2, 12), - (10, 20, 20), - ], - ) - def test__flush_internal_with_errors( - self, num_starting, num_new_errors, expected_total_errors - ): - """errors returned from _execute_mutate_rows should be added to internal exceptions""" - from google.cloud.bigtable.data import exceptions - - num_entries = 10 - expected_errors = [ - exceptions.FailedMutationEntryError(mock.Mock(), mock.Mock(), ValueError()) - ] * num_new_errors - with self._make_one() as instance: - instance._oldest_exceptions = [mock.Mock()] * num_starting - with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: - execute_mock.return_value = expected_errors - with mock.patch.object( - instance._flow_control, "add_to_flow" - ) as flow_mock: - - def gen(x): - yield x - - flow_mock.side_effect = lambda x: gen(x) - mutations = [self._make_mutation(count=1, size=1)] * num_entries - instance._flush_internal(mutations) - assert instance._entries_processed_since_last_raise == num_entries - assert execute_mock.call_count == 1 - assert flow_mock.call_count == 1 - found_exceptions = instance._oldest_exceptions + list( - instance._newest_exceptions - ) - assert len(found_exceptions) == expected_total_errors - for i in range(num_starting, expected_total_errors): - assert found_exceptions[i] == expected_errors[i - num_starting] - assert found_exceptions[i].index is None - instance._oldest_exceptions.clear() - instance._newest_exceptions.clear() - - def _mock_gapic_return(self, num=5): - from google.cloud.bigtable_v2.types import MutateRowsResponse - from google.rpc import status_pb2 - - def gen(num): - for i in range(num): - entry = MutateRowsResponse.Entry( - index=i, status=status_pb2.Status(code=0) - ) - yield MutateRowsResponse(entries=[entry]) - - return gen(num) - - def test_timer_flush_end_to_end(self): - """Flush should automatically trigger after flush_interval""" - num_mutations = 10 - mutations = [self._make_mutation(count=2, size=2)] * num_mutations - with self._make_one(flush_interval=0.05) as instance: - instance._table.default_operation_timeout = 10 - instance._table.default_attempt_timeout = 9 - with mock.patch.object( - instance._table.client._gapic_client, "mutate_rows" - ) as gapic_mock: - gapic_mock.side_effect = ( - lambda *args, **kwargs: self._mock_gapic_return(num_mutations) - ) - for m in mutations: - instance.append(m) - assert instance._entries_processed_since_last_raise == 0 - CrossSync._Sync_Impl.sleep(0.1) - assert 
instance._entries_processed_since_last_raise == num_mutations - - def test__execute_mutate_rows(self): - if CrossSync._Sync_Impl.is_async: - mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" - else: - mutate_path = "_sync.mutations_batcher._MutateRowsOperation" - with mock.patch(f"google.cloud.bigtable.data.{mutate_path}") as mutate_rows: - mutate_rows.return_value = CrossSync._Sync_Impl.Mock() - start_operation = mutate_rows().start - table = mock.Mock() - table.table_name = "test-table" - table.app_profile_id = "test-app-profile" - table.default_mutate_rows_operation_timeout = 17 - table.default_mutate_rows_attempt_timeout = 13 - table.default_mutate_rows_retryable_errors = () - with self._make_one(table) as instance: - batch = [self._make_mutation()] - result = instance._execute_mutate_rows(batch) - assert start_operation.call_count == 1 - args, kwargs = mutate_rows.call_args - assert args[0] == table.client._gapic_client - assert args[1] == table - assert args[2] == batch - assert kwargs["operation_timeout"] == 17 - assert kwargs["attempt_timeout"] == 13 - assert result == [] - - def test__execute_mutate_rows_returns_errors(self): - """Errors from operation should be returned as list""" - from google.cloud.bigtable.data.exceptions import ( - MutationsExceptionGroup, - FailedMutationEntryError, - ) - - if CrossSync._Sync_Impl.is_async: - mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" - else: - mutate_path = "_sync.mutations_batcher._MutateRowsOperation" - with mock.patch( - f"google.cloud.bigtable.data.{mutate_path}.start" - ) as mutate_rows: - err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) - err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) - mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 17 - table.default_mutate_rows_attempt_timeout = 13 - table.default_mutate_rows_retryable_errors = () - with self._make_one(table) as instance: - batch = [self._make_mutation()] - result = instance._execute_mutate_rows(batch) - assert len(result) == 2 - assert result[0] == err1 - assert result[1] == err2 - assert result[0].index is None - assert result[1].index is None - - def test__raise_exceptions(self): - """Raise exceptions and reset error state""" - from google.cloud.bigtable.data import exceptions - - expected_total = 1201 - expected_exceptions = [RuntimeError("mock")] * 3 - with self._make_one() as instance: - instance._oldest_exceptions = expected_exceptions - instance._entries_processed_since_last_raise = expected_total - try: - instance._raise_exceptions() - except exceptions.MutationsExceptionGroup as exc: - assert list(exc.exceptions) == expected_exceptions - assert str(expected_total) in str(exc) - assert instance._entries_processed_since_last_raise == 0 - instance._oldest_exceptions, instance._newest_exceptions = ([], []) - instance._raise_exceptions() - - def test___enter__(self): - """Should return self""" - with self._make_one() as instance: - assert instance.__enter__() == instance - - def test___exit__(self): - """__exit__ should call close""" - with self._make_one() as instance: - with mock.patch.object(instance, "close") as close_mock: - instance.__exit__(None, None, None) - assert close_mock.call_count == 1 - - def test_close(self): - """Should clean up all resources""" - with self._make_one() as instance: - with mock.patch.object(instance, "_schedule_flush") as flush_mock: - with mock.patch.object(instance, 
"_raise_exceptions") as raise_mock: - instance.close() - assert instance.closed is True - assert instance._flush_timer.done() is True - assert instance._flush_jobs == set() - assert flush_mock.call_count == 1 - assert raise_mock.call_count == 1 - - def test_close_w_exceptions(self): - """Raise exceptions on close""" - from google.cloud.bigtable.data import exceptions - - expected_total = 10 - expected_exceptions = [RuntimeError("mock")] - with self._make_one() as instance: - instance._oldest_exceptions = expected_exceptions - instance._entries_processed_since_last_raise = expected_total - try: - instance.close() - except exceptions.MutationsExceptionGroup as exc: - assert list(exc.exceptions) == expected_exceptions - assert str(expected_total) in str(exc) - assert instance._entries_processed_since_last_raise == 0 - instance._oldest_exceptions, instance._newest_exceptions = ([], []) - - def test__on_exit(self, recwarn): - """Should raise warnings if unflushed mutations exist""" - with self._make_one() as instance: - instance._on_exit() - assert len(recwarn) == 0 - num_left = 4 - instance._staged_entries = [mock.Mock()] * num_left - with pytest.warns(UserWarning) as w: - instance._on_exit() - assert len(w) == 1 - assert "unflushed mutations" in str(w[0].message).lower() - assert str(num_left) in str(w[0].message) - instance._closed.set() - instance._on_exit() - assert len(recwarn) == 0 - instance._staged_entries = [] - - def test_atexit_registration(self): - """Should run _on_exit on program termination""" - import atexit - - with mock.patch.object(atexit, "register") as register_mock: - assert register_mock.call_count == 0 - with self._make_one(): - assert register_mock.call_count == 1 - - def test_timeout_args_passed(self): - """batch_operation_timeout and batch_attempt_timeout should be used - in api calls""" - if CrossSync._Sync_Impl.is_async: - mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" - else: - mutate_path = "_sync.mutations_batcher._MutateRowsOperation" - with mock.patch( - f"google.cloud.bigtable.data.{mutate_path}", - return_value=CrossSync._Sync_Impl.Mock(), - ) as mutate_rows: - expected_operation_timeout = 17 - expected_attempt_timeout = 13 - with self._make_one( - batch_operation_timeout=expected_operation_timeout, - batch_attempt_timeout=expected_attempt_timeout, - ) as instance: - assert instance._operation_timeout == expected_operation_timeout - assert instance._attempt_timeout == expected_attempt_timeout - instance._execute_mutate_rows([self._make_mutation()]) - assert mutate_rows.call_count == 1 - kwargs = mutate_rows.call_args[1] - assert kwargs["operation_timeout"] == expected_operation_timeout - assert kwargs["attempt_timeout"] == expected_attempt_timeout - - @pytest.mark.parametrize( - "limit,in_e,start_e,end_e", - [ - (10, 0, (10, 0), (10, 0)), - (1, 10, (0, 0), (1, 1)), - (10, 1, (0, 0), (1, 0)), - (10, 10, (0, 0), (10, 0)), - (10, 11, (0, 0), (10, 1)), - (3, 20, (0, 0), (3, 3)), - (10, 20, (0, 0), (10, 10)), - (10, 21, (0, 0), (10, 10)), - (2, 1, (2, 0), (2, 1)), - (2, 1, (1, 0), (2, 0)), - (2, 2, (1, 0), (2, 1)), - (3, 1, (3, 1), (3, 2)), - (3, 3, (3, 1), (3, 3)), - (1000, 5, (999, 0), (1000, 4)), - (1000, 5, (0, 0), (5, 0)), - (1000, 5, (1000, 0), (1000, 5)), - ], - ) - def test__add_exceptions(self, limit, in_e, start_e, end_e): - """Test that the _add_exceptions function properly updates the - _oldest_exceptions and _newest_exceptions lists - Args: - - limit: the _exception_list_limit representing the max size of either list - - in_e: 
size of list of exceptions to send to _add_exceptions - - start_e: a tuple of ints representing the initial sizes of _oldest_exceptions and _newest_exceptions - - end_e: a tuple of ints representing the expected sizes of _oldest_exceptions and _newest_exceptions - """ - from collections import deque - - input_list = [RuntimeError(f"mock {i}") for i in range(in_e)] - mock_batcher = mock.Mock() - mock_batcher._oldest_exceptions = [ - RuntimeError(f"starting mock {i}") for i in range(start_e[0]) - ] - mock_batcher._newest_exceptions = deque( - [RuntimeError(f"starting mock {i}") for i in range(start_e[1])], - maxlen=limit, - ) - mock_batcher._exception_list_limit = limit - mock_batcher._exceptions_since_last_raise = 0 - self._get_target_class()._add_exceptions(mock_batcher, input_list) - assert len(mock_batcher._oldest_exceptions) == end_e[0] - assert len(mock_batcher._newest_exceptions) == end_e[1] - assert mock_batcher._exceptions_since_last_raise == in_e - oldest_list_diff = end_e[0] - start_e[0] - newest_list_diff = min(max(in_e - oldest_list_diff, 0), limit) - for i in range(oldest_list_diff): - assert mock_batcher._oldest_exceptions[i + start_e[0]] == input_list[i] - for i in range(1, newest_list_diff + 1): - assert mock_batcher._newest_exceptions[-i] == input_list[-i] - - @pytest.mark.parametrize( - "input_retryables,expected_retryables", - [ - ( - TABLE_DEFAULT.READ_ROWS, - [ - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - core_exceptions.Aborted, - ], - ), - ( - TABLE_DEFAULT.DEFAULT, - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ), - ( - TABLE_DEFAULT.MUTATE_ROWS, - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], - ), - ([], []), - ([4], [core_exceptions.DeadlineExceeded]), - ], - ) - def test_customizable_retryable_errors(self, input_retryables, expected_retryables): - """Test that retryable functions support user-configurable arguments, and that the configured retryables are passed - down to the gapic layer.""" - with mock.patch.object( - google.api_core.retry, "if_exception_type" - ) as predicate_builder_mock: - with mock.patch.object( - CrossSync._Sync_Impl, "retry_target" - ) as retry_fn_mock: - table = None - with mock.patch("asyncio.create_task"): - table = Table(mock.Mock(), "instance", "table") - with self._make_one( - table, batch_retryable_errors=input_retryables - ) as instance: - assert instance._retryable_errors == expected_retryables - expected_predicate = expected_retryables.__contains__ - predicate_builder_mock.return_value = expected_predicate - retry_fn_mock.side_effect = RuntimeError("stop early") - mutation = self._make_mutation(count=1, size=1) - instance._execute_mutate_rows([mutation]) - predicate_builder_mock.assert_called_once_with( - *expected_retryables, _MutateRowsIncomplete - ) - retry_call_args = retry_fn_mock.call_args_list[0].args - assert retry_call_args[1] is expected_predicate diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync/test_read_rows_acceptance.py deleted file mode 100644 index 6baef4a4d..000000000 --- a/tests/unit/data/_sync/test_read_rows_acceptance.py +++ /dev/null @@ -1,333 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
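A reading aid for the test__add_exceptions cases above: the batcher keeps the first `limit` errors in a stable "oldest" list and lets any later errors cycle through a bounded deque, so at most 2 * limit exceptions are retained between raises (the real _add_exceptions also updates _exceptions_since_last_raise, as the assertions show). A minimal stdlib-only sketch of that bookkeeping, not the batcher's actual code:

from collections import deque

def add_exceptions(oldest: list, newest: deque, limit: int, new_errors: list) -> None:
    # Fill the stable "oldest" buffer first; once it holds `limit` entries,
    # later errors roll through the bounded "newest" deque instead.
    for exc in new_errors:
        if len(oldest) < limit:
            oldest.append(exc)
        else:
            newest.append(exc)  # deque(maxlen=limit) silently discards the surplus

oldest: list = []
newest: deque = deque(maxlen=3)
add_exceptions(oldest, newest, limit=3,
               new_errors=[RuntimeError(f"mock {i}") for i in range(20)])
assert (len(oldest), len(newest)) == (3, 3)  # mirrors the (3, 20, (0, 0), (3, 3)) case above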
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This file is automatically generated by CrossSync. Do not edit manually. -from __future__ import annotations -import os -import warnings -import pytest -import mock -from itertools import zip_longest -from google.cloud.bigtable_v2 import ReadRowsResponse -from google.cloud.bigtable.data.exceptions import InvalidChunk -from google.cloud.bigtable.data.row import Row -from ...v2_client.test_row_merger import ReadRowsTest, TestFile -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -if CrossSync._Sync_Impl.is_async: - pass -else: - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - from google.cloud.bigtable.data._sync.client import BigtableDataClient - - -class TestReadRowsAcceptance: - @staticmethod - def _get_operation_class(): - return _ReadRowsOperation - - @staticmethod - def _get_client_class(): - return BigtableDataClient - - def parse_readrows_acceptance_tests(): - dirname = os.path.dirname(__file__) - filename = os.path.join(dirname, "../read-rows-acceptance-test.json") - with open(filename) as json_file: - test_json = TestFile.from_json(json_file.read()) - return test_json.read_rows_tests - - @staticmethod - def extract_results_from_row(row: Row): - results = [] - for family, col, cells in row.items(): - for cell in cells: - results.append( - ReadRowsTest.Result( - row_key=row.row_key, - family_name=family, - qualifier=col, - timestamp_micros=cell.timestamp_ns // 1000, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - ) - return results - - @staticmethod - def _coro_wrapper(stream): - return stream - - def _process_chunks(self, *chunks): - def _row_stream(): - yield ReadRowsResponse(chunks=chunks) - - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = None - chunker = self._get_operation_class().chunk_stream( - instance, self._coro_wrapper(_row_stream()) - ) - merger = self._get_operation_class().merge_rows(chunker) - results = [] - for row in merger: - results.append(row) - return results - - @pytest.mark.parametrize( - "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description - ) - def test_row_merger_scenario(self, test_case: ReadRowsTest): - def _scenerio_stream(): - for chunk in test_case.chunks: - yield ReadRowsResponse(chunks=[chunk]) - - try: - results = [] - instance = mock.Mock() - instance._last_yielded_row_key = None - instance._remaining_count = None - chunker = self._get_operation_class().chunk_stream( - instance, self._coro_wrapper(_scenerio_stream()) - ) - merger = self._get_operation_class().merge_rows(chunker) - for row in merger: - for cell in row: - cell_result = ReadRowsTest.Result( - row_key=cell.row_key, - family_name=cell.family, - qualifier=cell.qualifier, - timestamp_micros=cell.timestamp_micros, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - @pytest.mark.parametrize( - "test_case", 
parse_readrows_acceptance_tests(), ids=lambda t: t.description - ) - def test_read_rows_scenario(self, test_case: ReadRowsTest): - def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): - from google.cloud.bigtable_v2 import ReadRowsResponse - - class mock_stream: - def __init__(self, chunk_list): - self.chunk_list = chunk_list - self.idx = -1 - - def __aiter__(self): - return self - - def __iter__(self): - return self - - def __anext__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - chunk = self.chunk_list[self.idx] - return ReadRowsResponse(chunks=[chunk]) - raise CrossSync._Sync_Impl.StopIteration - - def __next__(self): - return self.__anext__() - - def cancel(self): - pass - - return mock_stream(chunk_list) - - with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - client = self._get_client_class()() - try: - table = client.get_table("instance", "table") - results = [] - with mock.patch.object( - table.client._gapic_client, "read_rows" - ) as read_rows: - read_rows.return_value = _make_gapic_stream(test_case.chunks) - for row in table.read_rows_stream(query={}): - for cell in row: - cell_result = ReadRowsTest.Result( - row_key=cell.row_key, - family_name=cell.family, - qualifier=cell.qualifier, - timestamp_micros=cell.timestamp_micros, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - finally: - client.close() - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - def test_out_of_order_rows(self): - def _row_stream(): - yield ReadRowsResponse(last_scanned_row_key=b"a") - - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = b"b" - chunker = self._get_operation_class().chunk_stream( - instance, self._coro_wrapper(_row_stream()) - ) - merger = self._get_operation_class().merge_rows(chunker) - with pytest.raises(InvalidChunk): - for _ in merger: - pass - - def test_bare_reset(self): - first_chunk = ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk( - row_key=b"a", family_name="f", qualifier=b"q", value=b"v" - ) - ) - with pytest.raises(InvalidChunk): - self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") - ), - ) - with pytest.raises(InvalidChunk): - self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, family_name="f") - ), - ) - with pytest.raises(InvalidChunk): - self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") - ), - ) - with pytest.raises(InvalidChunk): - self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) - ), - ) - with pytest.raises(InvalidChunk): - self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) - ), - ) - with pytest.raises(InvalidChunk): - self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, value=b"v") - ), - ) - - def test_missing_family(self): - with pytest.raises(InvalidChunk): - self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - qualifier=b"q", - timestamp_micros=1000, - value=b"v", - commit_row=True, - ) - ) 
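The test_bare_reset cases above encode one of the chunk-merging invariants: a CellChunk that sets reset_row may not carry any other payload. A self-contained sketch of that rule, using a plain dataclass as a hypothetical stand-in for ReadRowsResponse.CellChunk (the real merger raises InvalidChunk instead of ValueError):

from dataclasses import dataclass, field

@dataclass
class Chunk:
    # Illustrative stand-in; fields limited to what the rule needs.
    reset_row: bool = False
    row_key: bytes = b""
    family_name: str = ""
    qualifier: bytes = b""
    timestamp_micros: int = 0
    labels: list = field(default_factory=list)
    value: bytes = b""

def check_bare_reset(chunk: Chunk) -> None:
    # A reset chunk must be "bare": any accompanying data makes the stream invalid.
    if chunk.reset_row and (chunk.row_key or chunk.family_name or chunk.qualifier
                            or chunk.timestamp_micros or chunk.labels or chunk.value):
        raise ValueError("InvalidChunk: reset_row chunk carries extra data")

check_bare_reset(Chunk(reset_row=True))                    # a bare reset is allowed
try:
    check_bare_reset(Chunk(reset_row=True, row_key=b"a"))  # rejected, as in the tests above
except ValueError:
    pass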
- - def test_mid_cell_row_key_change(self): - with pytest.raises(InvalidChunk): - self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), - ) - - def test_mid_cell_family_change(self): - with pytest.raises(InvalidChunk): - self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - family_name="f2", value=b"v", commit_row=True - ), - ) - - def test_mid_cell_qualifier_change(self): - with pytest.raises(InvalidChunk): - self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - qualifier=b"q2", value=b"v", commit_row=True - ), - ) - - def test_mid_cell_timestamp_change(self): - with pytest.raises(InvalidChunk): - self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - timestamp_micros=2000, value=b"v", commit_row=True - ), - ) - - def test_mid_cell_labels_change(self): - with pytest.raises(InvalidChunk): - self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), - ) From 245bd0894a61b3f309d79cdfa040dd344eabc29f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 14:54:16 -0700 Subject: [PATCH 146/360] removed conversion decorators --- .../bigtable/data/_async/_mutate_rows.py | 11 ---- .../cloud/bigtable/data/_async/_read_rows.py | 8 --- google/cloud/bigtable/data/_async/client.py | 43 ------------- .../bigtable/data/_async/mutations_batcher.py | 22 ------- .../cloud/bigtable/data/_sync/cross_sync.py | 11 ---- tests/system/data/test_system_async.py | 11 ---- tests/unit/data/_async/test__mutate_rows.py | 4 -- tests/unit/data/_async/test__read_rows.py | 11 ---- tests/unit/data/_async/test_client.py | 64 ------------------- .../data/_async/test_mutations_batcher.py | 15 ----- .../data/_async/test_read_rows_acceptance.py | 11 ---- 11 files changed, 211 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 6d4d2f2e8..2e7181695 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -62,9 +62,6 @@ class _EntryWithProto: # noqa: F811 proto: types_pb.MutateRowsRequest.Entry -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", -) class _MutateRowsOperationAsync: """ MutateRowsOperation manages the logic of sending a set of row mutations, @@ -84,12 +81,6 @@ class _MutateRowsOperationAsync: If not specified, the request will run until operation_timeout is reached. 
""" - @CrossSync.convert( - replace_symbols={ - "BigtableAsyncClient": "BigtableClient", - "TableAsync": "Table", - } - ) def __init__( self, gapic_client: "BigtableAsyncClient", @@ -141,7 +132,6 @@ def __init__( self.remaining_indices = list(range(len(self.mutations))) self.errors: dict[int, list[Exception]] = {} - @CrossSync.convert async def start(self): """ Start the operation, and run until completion @@ -174,7 +164,6 @@ async def start(self): if all_errors: raise MutationsExceptionGroup(all_errors, len(self.mutations)) - @CrossSync.convert async def _run_attempt(self): """ Run a single attempt of the mutate_rows rpc. diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 8c982427c..9285d5f6f 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -51,9 +51,6 @@ def __init__(self, chunk): self.chunk = chunk -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", -) class _ReadRowsOperationAsync: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream @@ -85,7 +82,6 @@ class _ReadRowsOperationAsync: "_remaining_count", ) - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def __init__( self, query: ReadRowsQuery, @@ -166,7 +162,6 @@ def _read_rows_attempt(self) -> CrossSync.Iterable[Row]: chunked_stream = self.chunk_stream(gapic_stream) return self.merge_rows(chunked_stream) - @CrossSync.convert async def chunk_stream( self, stream: CrossSync.Awaitable[CrossSync.Iterable[ReadRowsResponsePB]] ) -> CrossSync.Iterable[ReadRowsResponsePB.CellChunk]: @@ -219,9 +214,6 @@ async def chunk_stream( current_key = None @staticmethod - @CrossSync.convert( - replace_symbols={"__aiter__": "__iter__", "__anext__": "__next__"} - ) async def merge_rows( chunks: CrossSync.Iterable[ReadRowsResponsePB.CellChunk] | None, ) -> CrossSync.Iterable[Row]: diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index e7d84ebf1..1c52f83c2 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -110,17 +110,7 @@ from google.cloud.bigtable.data._helpers import ShardedQuery -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync.client.BigtableDataClient", -) class BigtableDataClientAsync(ClientWithProject): - @CrossSync.convert( - replace_symbols={ - "BigtableAsyncClient": "BigtableClient", - "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", - "AsyncPooledChannel": "PooledChannel", - } - ) def __init__( self, *, @@ -266,7 +256,6 @@ def _start_background_channel_refresh(self) -> None: lambda _: self._channel_refresh_tasks.remove(refresh_task) if refresh_task in self._channel_refresh_tasks else None ) - @CrossSync.convert async def close(self, timeout: float | None = None): """ Cancel all background tasks @@ -279,7 +268,6 @@ async def close(self, timeout: float | None = None): self._executor.shutdown(wait=False) await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) - @CrossSync.convert async def _ping_and_warm_instances( self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None ) -> list[BaseException | None]: @@ -321,7 +309,6 @@ async def _ping_and_warm_instances( ) return [r or None for r in result_list] - @CrossSync.convert async def _manage_channel( self, channel_idx: int, @@ -378,7 +365,6 @@ async def _manage_channel( next_refresh = 
random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = next_refresh - (time.monotonic() - start_timestamp) - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: """ Registers an instance with the client, and warms the channel pool @@ -409,7 +395,6 @@ async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: # refresh tasks aren't active. start them as background tasks self._start_background_channel_refresh() - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def _remove_instance_registration( self, instance_id: str, owner: TableAsync ) -> bool: @@ -440,7 +425,6 @@ async def _remove_instance_registration( except KeyError: return False - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: """ Returns a table instance for making data API requests. All arguments are passed @@ -482,18 +466,15 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs """ return TableAsync(self, instance_id, table_id, *args, **kwargs) - @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): self._start_background_channel_refresh() return self - @CrossSync.convert(sync_name="__exit__", replace_symbols={"__aexit__": "__exit__"}) async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) -@CrossSync.export_sync(path="google.cloud.bigtable.data._sync.client.Table") class TableAsync: """ Main Data API surface @@ -502,9 +483,6 @@ class TableAsync: each call """ - @CrossSync.convert( - replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} - ) def __init__( self, client: BigtableDataClientAsync, @@ -625,12 +603,6 @@ def __init__( f"{self.__class__.__name__} must be created within an async event loop context." 
) from e - @CrossSync.convert( - replace_symbols={ - "AsyncIterable": "Iterable", - "_ReadRowsOperationAsync": "_ReadRowsOperation", - } - ) async def read_rows_stream( self, query: ReadRowsQuery, @@ -681,7 +653,6 @@ async def read_rows_stream( ) return row_merger.start_operation() - @CrossSync.convert async def read_rows( self, query: ReadRowsQuery, @@ -729,7 +700,6 @@ async def read_rows( ) return [row async for row in row_generator] - @CrossSync.convert async def read_row( self, row_key: str | bytes, @@ -779,7 +749,6 @@ async def read_row( return None return results[0] - @CrossSync.convert async def read_rows_sharded( self, sharded_query: ShardedQuery, @@ -879,7 +848,6 @@ async def read_rows_with_semaphore(query): ) return results_list - @CrossSync.convert async def row_exists( self, row_key: str | bytes, @@ -928,7 +896,6 @@ async def row_exists( ) return len(results) > 0 - @CrossSync.convert async def sample_row_keys( self, *, @@ -1001,7 +968,6 @@ async def execute_rpc(): exception_factory=_helpers._retry_exception_factory, ) - @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) def mutations_batcher( self, *, @@ -1051,7 +1017,6 @@ def mutations_batcher( batch_retryable_errors=batch_retryable_errors, ) - @CrossSync.convert async def mutate_row( self, row_key: str | bytes, @@ -1130,9 +1095,6 @@ async def mutate_row( exception_factory=_helpers._retry_exception_factory, ) - @CrossSync.convert( - replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"} - ) async def bulk_mutate_rows( self, mutation_entries: list[RowMutationEntry], @@ -1188,7 +1150,6 @@ async def bulk_mutate_rows( ) await operation.start() - @CrossSync.convert async def check_and_mutate_row( self, row_key: str | bytes, @@ -1255,7 +1216,6 @@ async def check_and_mutate_row( ) return result.predicate_matched - @CrossSync.convert async def read_modify_write_row( self, row_key: str | bytes, @@ -1306,7 +1266,6 @@ async def read_modify_write_row( # construct Row from result return Row._from_pb(result.row) - @CrossSync.convert async def close(self): """ Called to close the Table instance and release any resources held by it. @@ -1315,7 +1274,6 @@ async def close(self): self._register_instance_future.cancel() await self.client._remove_instance_registration(self.instance_id, self) - @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """ Implement async context manager protocol @@ -1327,7 +1285,6 @@ async def __aenter__(self): await self._register_instance_future return self - @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc_val, exc_tb): """ Implement async context manager protocol diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index b9a6a3339..07eac0e26 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -51,9 +51,6 @@ from google.cloud.bigtable.data._sync.client import Table # noqa: F401 -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync.mutations_batcher._FlowControl" -) class _FlowControlAsync: """ Manages flow control for batched mutations. 
Mutations are registered against @@ -110,7 +107,6 @@ def _has_capacity(self, additional_count: int, additional_size: int) -> bool: new_count = self._in_flight_mutation_count + additional_count return new_size <= acceptable_size and new_count <= acceptable_count - @CrossSync.convert async def remove_from_flow( self, mutations: RowMutationEntry | list[RowMutationEntry] ) -> None: @@ -132,7 +128,6 @@ async def remove_from_flow( async with self._capacity_condition: self._capacity_condition.notify_all() - @CrossSync.convert async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): """ Generator function that registers mutations with flow control. As mutations @@ -182,10 +177,6 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] yield mutations[start_idx:end_idx] -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", - mypy_ignore=["unreachable"], -) class MutationsBatcherAsync: """ Allows users to send batches using context manager API: @@ -217,9 +208,6 @@ class MutationsBatcherAsync: Defaults to the Table's default_mutate_rows_retryable_errors. """ - @CrossSync.convert( - replace_symbols={"TableAsync": "Table", "_FlowControlAsync": "_FlowControl"} - ) def __init__( self, table: TableAsync, @@ -275,7 +263,6 @@ def __init__( # clean up on program exit atexit.register(self._on_exit) - @CrossSync.convert async def _timer_routine(self, interval: float | None) -> None: """ Set up a background task to flush the batcher every interval seconds @@ -296,7 +283,6 @@ async def _timer_routine(self, interval: float | None) -> None: if not self._closed.is_set() and self._staged_entries: self._schedule_flush() - @CrossSync.convert async def append(self, mutation_entry: RowMutationEntry): """ Add a new set of mutations to the internal queue @@ -346,7 +332,6 @@ def _schedule_flush(self) -> CrossSync.Future[None] | None: return new_task return None - @CrossSync.convert async def _flush_internal(self, new_entries: list[RowMutationEntry]): """ Flushes a set of mutations to the server, and updates internal state @@ -367,9 +352,6 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): self._entries_processed_since_last_raise += len(new_entries) self._add_exceptions(found_exceptions) - @CrossSync.convert( - replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"} - ) async def _execute_mutate_rows( self, batch: list[RowMutationEntry] ) -> list[FailedMutationEntryError]: @@ -450,12 +432,10 @@ def _raise_exceptions(self): entry_count=entry_count, ) - @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """Allow use of context manager API""" return self - @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc, tb): """ Allow use of context manager API. 
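A compact way to read the _has_capacity logic shown earlier in this file's diff: an entry is admitted while the in-flight totals stay within the configured limits, and a single entry larger than the limits is still admitted on its own so it can never wait forever. A rough standalone sketch under that assumption, with illustrative names rather than the library's:

class SimpleFlowControl:
    # Sketch only: tracks in-flight totals and decides whether a new entry fits.
    def __init__(self, max_count: int, max_bytes: int):
        self.max_count = max_count
        self.max_bytes = max_bytes
        self.in_flight_count = 0
        self.in_flight_bytes = 0

    def has_capacity(self, extra_count: int, extra_bytes: int) -> bool:
        # Raising the ceiling to the entry's own size lets an oversized entry
        # through when nothing else is in flight.
        acceptable_count = max(self.max_count, extra_count)
        acceptable_bytes = max(self.max_bytes, extra_bytes)
        return (self.in_flight_count + extra_count <= acceptable_count
                and self.in_flight_bytes + extra_bytes <= acceptable_bytes)

flow = SimpleFlowControl(max_count=2, max_bytes=100)
assert flow.has_capacity(1, 50)        # within limits
assert flow.has_capacity(1, 500)       # oversized single entry still admitted
flow.in_flight_count, flow.in_flight_bytes = 2, 100
assert not flow.has_capacity(1, 1)     # limits reached: caller must wait for capacity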
@@ -472,7 +452,6 @@ def closed(self) -> bool: """ return self._closed.is_set() - @CrossSync.convert async def close(self): """ Flush queue and clean up resources @@ -500,7 +479,6 @@ def _on_exit(self): ) @staticmethod - @CrossSync.convert async def _wait_for_batch_results( *tasks: CrossSync.Future[list[FailedMutationEntryError]] | CrossSync.Future[None], diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index dd87a63b5..e3f1169fa 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -159,17 +159,6 @@ class CrossSync(metaclass=_DecoratorMeta): Generator: TypeAlias = AsyncGenerator _decorators: list[AstDecorator] = [ - AstDecorator("export_sync", # decorate classes to convert - required_keywords=["path"], # otput path for generated sync class - replace_symbols={}, # replace specific symbols across entire class - mypy_ignore=(), # set of mypy error codes to ignore in output file - include_file_imports=True # when True, import statements from top of file will be included in output file - ), - AstDecorator("convert", # decorate methods to convert from async to sync - sync_name=None, # use a new name for the sync class - replace_symbols={}, # replace specific symbols within the function - ), - AstDecorator("drop_method"), # decorate methods to drop in sync version of class AstDecorator("pytest", inner_decorator=pytest_mark_asyncio), # decorate test methods to run with pytest-asyncio AstDecorator("pytest_fixture", # decorate test methods to run with pytest fixture inner_decorator=pytest_asyncio_fixture, diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 32ff5f49c..ed3435e39 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -34,7 +34,6 @@ TEST_FAMILY = "test-family" TEST_FAMILY_2 = "test-family-2" -@CrossSync.export_sync(path="tests.system.data.test_system.TempRowBuilder") class TempRowBuilderAsync: """ Used to add rows to a table for testing purposes. 
@@ -44,7 +43,6 @@ def __init__(self, table): self.rows = [] self.table = table - @CrossSync.convert async def add_row( self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" ): @@ -68,7 +66,6 @@ async def add_row( await self.table.client._gapic_client.mutate_row(request) self.rows.append(row_key) - @CrossSync.convert async def delete_rows(self): if self.rows: request = { @@ -81,17 +78,14 @@ async def delete_rows(self): await self.table.client._gapic_client.mutate_rows(request) -@CrossSync.export_sync(path="tests.system.data.test_system.TestSystem") class TestSystemAsync: - @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}) @CrossSync.pytest_fixture(scope="session") async def client(self): project = os.getenv("GOOGLE_CLOUD_PROJECT") or None async with BigtableDataClientAsync(project=project, pool_size=4) as client: yield client - @CrossSync.convert @CrossSync.pytest_fixture(scope="session") async def table(self, client, table_id, instance_id): async with client.get_table( @@ -139,7 +133,6 @@ def cluster_config(self, project_id): } return cluster - @CrossSync.convert @pytest.mark.usefixtures("table") async def _retrieve_cell_value(self, table, row_key): """ @@ -153,7 +146,6 @@ async def _retrieve_cell_value(self, table, row_key): cell = row.cells[0] return cell.value - @CrossSync.convert async def _create_row_and_mutation( self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" ): @@ -174,7 +166,6 @@ async def _create_row_and_mutation( mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) return row_key, mutation - @CrossSync.convert(replace_symbols={"TempRowBuilderAsync": "TempRowBuilder"}) @CrossSync.pytest_fixture(scope="function") async def temp_rows(self, table): builder = TempRowBuilderAsync(table) @@ -665,7 +656,6 @@ async def test_check_and_mutate_empty_request(self, client, table): assert "No mutations provided" in str(e.value) @pytest.mark.usefixtures("table") - @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @@ -854,7 +844,6 @@ async def test_read_rows_with_filter(self, table, temp_rows): assert row[0].labels == [expected_label] @pytest.mark.usefixtures("table") - @CrossSync.convert(replace_symbols={"__anext__": "__next__", "aclose": "close"}) @CrossSync.pytest async def test_read_rows_stream_close(self, table, temp_rows): """ diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 292cbd692..b7016be81 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -28,9 +28,6 @@ import mock # type: ignore -@CrossSync.export_sync( - path="tests.unit.data._sync.test__mutate_rows.TestMutateRowsOperation", -) class TestMutateRowsOperation: def _target_class(self): if CrossSync.is_async: @@ -62,7 +59,6 @@ def _make_mutation(self, count=1, size=1): mutation.mutations = [mock.Mock()] * count return mutation - @CrossSync.convert async def _mock_stream(self, mutation_list, error_dict): for idx, entry in enumerate(mutation_list): code = error_dict.get(idx, 0) diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 076e86788..e04436ee1 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -34,9 +34,6 @@ TEST_LABELS = ["label1", "label2"] -@CrossSync.export_sync( - 
path="tests.unit.data._sync.test__read_rows.TestReadRowsOperation", -) class TestReadRowsOperationAsync: """ Tests helper functions in the ReadRowsOperation class @@ -45,9 +42,6 @@ class TestReadRowsOperationAsync: """ @staticmethod - @CrossSync.convert( - replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} - ) def _get_target_class(): return _ReadRowsOperationAsync @@ -332,10 +326,6 @@ async def mock_stream(): assert "emit count exceeds row limit" in str(e.value) @CrossSync.pytest - @CrossSync.convert( - sync_name="test_close", - replace_symbols={"aclose": "close", "__anext__": "__next__"}, - ) async def test_aclose(self): """ should be able to close a stream safely with aclose. @@ -365,7 +355,6 @@ async def mock_stream(): await wrapped_gen.__anext__() @CrossSync.pytest - @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) async def test_retryable_ignore_repeated_rows(self): """ Duplicate rows should cause an invalid chunk error diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 0f5775fac..5370f35d3 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -76,21 +76,8 @@ ) -@CrossSync.export_sync( - path="tests.unit.data._sync.test_client.TestBigtableDataClient", - replace_symbols={ - "TestTableAsync": "TestTable", - "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", - "grpc_helpers_async": "grpc_helpers", - "PooledChannelAsync": "PooledChannel", - "BigtableAsyncClient": "BigtableClient", - }, -) class TestBigtableDataClientAsync: @staticmethod - @CrossSync.convert( - replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} - ) def _get_target_class(): return BigtableDataClientAsync @@ -300,7 +287,6 @@ async def test_channel_pool_replace(self): assert client.transport._grpc_channel._pool[i] != start_pool[i] await client.close() - @CrossSync.drop_method @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test__start_background_channel_refresh_sync(self): # should raise RuntimeError if called in a sync context @@ -344,7 +330,6 @@ async def test__start_background_channel_refresh(self, pool_size): ping_and_warm.assert_any_call(channel) await client.close() - @CrossSync.drop_method @CrossSync.pytest @pytest.mark.skipif( sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" @@ -1125,7 +1110,6 @@ async def test_context_manager(self): # actually close the client await true_close - @CrossSync.drop_method def test_client_ctor_sync(self): # initializing client in a sync context should raise RuntimeError @@ -1141,16 +1125,11 @@ def test_client_ctor_sync(self): assert client._channel_refresh_tasks == [] -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestTable") class TestTableAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @staticmethod - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def _get_target_class(): return TableAsync @@ -1272,7 +1251,6 @@ async def test_table_ctor_invalid_timeout_values(self): assert "operation_timeout must be greater than 0" in str(e.value) await client.close() - @CrossSync.drop_method def test_table_ctor_sync(self): # initializing client in a sync context should raise RuntimeError client = mock.Mock() @@ -1422,7 +1400,6 @@ async def test_customizable_retryable_errors( ) @pytest.mark.parametrize("include_app_profile", 
[True, False]) @CrossSync.pytest - @CrossSync.convert(replace_symbols={"BigtableAsyncClient": "BigtableClient"}) async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): """check that all requests attach proper metadata headers""" profile = "profile" if include_app_profile else None @@ -1455,26 +1432,18 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ assert "app_profile_id=" not in goog_metadata -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRows") class TestReadRowsAsync: """ Tests for table.read_rows and related methods. """ @staticmethod - @CrossSync.convert( - replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} - ) def _get_operation_class(): return _ReadRowsOperationAsync - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) - @CrossSync.convert(replace_symbols={"TestTableAsync": "TestTable"}) def _make_table(self, *args, **kwargs): client_mock = mock.Mock() client_mock._register_instance.side_effect = ( @@ -1522,7 +1491,6 @@ def _make_chunk(*args, **kwargs): return ReadRowsResponse.CellChunk(*args, **kwargs) @staticmethod - @CrossSync.convert async def _make_gapic_stream( chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0, @@ -1561,7 +1529,6 @@ def cancel(self): return mock_stream(chunk_list, sleep_time) - @CrossSync.convert async def execute_fn(self, table, *args, **kwargs): return await table.read_rows(*args, **kwargs) @@ -1973,11 +1940,7 @@ async def test_row_exists(self, return_value, expected_result): assert query.filter._to_dict() == expected_filter -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRowsSharded") class TestReadRowsShardedAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -1990,7 +1953,6 @@ async def test_read_rows_sharded_empty_query(self): assert "empty sharded_query" in str(exc.value) @CrossSync.pytest - @CrossSync.convert(replace_symbols={"TestReadRowsAsync": "TestReadRows"}) async def test_read_rows_sharded_multiple_queries(self): """ Test with multiple queries. 
Should return results from both @@ -2198,15 +2160,10 @@ async def mock_call(*args, **kwargs): ) -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestSampleRowKeys") class TestSampleRowKeysAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) - @CrossSync.convert async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): from google.cloud.bigtable_v2.types import SampleRowKeysResponse @@ -2354,13 +2311,7 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio await table.sample_row_keys() -@CrossSync.export_sync( - path="tests.unit.data._sync.test_client.TestMutateRow", -) class TestMutateRowAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2535,17 +2486,10 @@ async def test_mutate_row_no_mutations(self, mutations): assert e.value.args[0] == "No mutations provided" -@CrossSync.export_sync( - path="tests.unit.data._sync.test_client.TestBulkMutateRows", -) class TestBulkMutateRowsAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) - @CrossSync.convert async def _mock_response(self, response_list): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -2921,11 +2865,7 @@ async def test_bulk_mutate_error_recovery(self): await table.bulk_mutate_rows(entries, operation_timeout=1000) -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestCheckAndMutateRow") class TestCheckAndMutateRowAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -3076,11 +3016,7 @@ async def test_check_and_mutate_mutations_parsing(self): ) -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadModifyWriteRow") class TestReadModifyWriteRowAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 85adae2d2..6386adc7f 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -43,10 +43,8 @@ import mock # type: ignore -@CrossSync.export_sync(path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl") class Test_FlowControl: @staticmethod - @CrossSync.convert(replace_symbols={"_FlowControlAsync": "_FlowControl"}) def _target_class(): return _FlowControlAsync @@ -321,11 +319,7 @@ async def test_add_to_flow_oversize(self): assert len(count_results) == 1 -@CrossSync.export_sync( - path="tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" -) class TestMutationsBatcherAsync: - @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) def _get_target_class(self): return MutationsBatcherAsync @@ -487,7 +481,6 @@ async def test_ctor_invalid_values(self): 
self._make_one(batch_attempt_timeout=-1) assert "attempt_timeout must be greater than 0" in str(e.value) - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def test_default_argument_consistency(self): """ We supply default arguments in MutationsBatcherAsync.__init__, and in @@ -908,7 +901,6 @@ async def gen(x): instance._oldest_exceptions.clear() instance._newest_exceptions.clear() - @CrossSync.convert async def _mock_gapic_return(self, num=5): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -1024,18 +1016,12 @@ async def test__raise_exceptions(self): instance._raise_exceptions() @CrossSync.pytest - @CrossSync.convert( - sync_name="test___enter__", replace_symbols={"__aenter__": "__enter__"} - ) async def test___aenter__(self): """Should return self""" async with self._make_one() as instance: assert await instance.__aenter__() == instance @CrossSync.pytest - @CrossSync.convert( - sync_name="test___exit__", replace_symbols={"__aexit__": "__exit__"} - ) async def test___aexit__(self): """aexit should call close""" async with self._make_one() as instance: @@ -1219,7 +1205,6 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): ([4], [core_exceptions.DeadlineExceeded]), ], ) - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def test_customizable_retryable_errors( self, input_retryables, expected_retryables ): diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 7cdd2c180..28d8c56f9 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -39,21 +39,12 @@ ) from google.cloud.bigtable.data._sync.client import BigtableDataClient # noqa: F401 -@CrossSync.export_sync( - path="tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", -) class TestReadRowsAcceptanceAsync: @staticmethod - @CrossSync.convert( - replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} - ) def _get_operation_class(): return _ReadRowsOperationAsync @staticmethod - @CrossSync.convert( - replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} - ) def _get_client_class(): return BigtableDataClientAsync @@ -83,11 +74,9 @@ def extract_results_from_row(row: Row): return results @staticmethod - @CrossSync.convert async def _coro_wrapper(stream): return stream - @CrossSync.convert async def _process_chunks(self, *chunks): async def _row_stream(): yield ReadRowsResponse(chunks=chunks) From 74a69c330c192f05a7779f1c44a84a5139d652ec Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 14:54:49 -0700 Subject: [PATCH 147/360] removed main function from cross_sync --- google/cloud/bigtable/data/_sync/cross_sync.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index e3f1169fa..4ce65c38e 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -389,20 +389,3 @@ def create_task( @staticmethod def yield_to_event_loop() -> None: pass - - -if __name__ == "__main__": - import glob - from google.cloud.bigtable.data._sync import transformers - - # find all cross_sync decorated classes - search_root = sys.argv[1] - # cross_sync_classes = load_classes_from_dir(search_root)\ - files = glob.glob(search_root + "/**/*.py", recursive=True) - artifacts: set[transformers.CrossSyncFileArtifact] = set() - 
for file in files: - converter = transformers.CrossSyncClassDecoratorHandler(file) - converter.convert_file(artifacts) - print(artifacts) - for artifact in artifacts: - artifact.render(save_to_disk=True) From 276add1929a5d2573f59dbf1f80fe1026b6d917a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 14:55:46 -0700 Subject: [PATCH 148/360] removed sync classes from __init__.py --- google/cloud/bigtable/data/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index b52d36b50..66fe3479b 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -48,9 +48,6 @@ __version__: str = package_version.__version__ __all__ = ( - "BigtableDataClient", - "Table", - "MutationsBatcher", "BigtableDataClientAsync", "TableAsync", "MutationsBatcherAsync", From d3906bfad3773dd09ec950450f67697bc65c5e03 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 14:59:55 -0700 Subject: [PATCH 149/360] removed unused file --- sync_surface_generator.py | 384 -------------------------------------- 1 file changed, 384 deletions(-) delete mode 100644 sync_surface_generator.py diff --git a/sync_surface_generator.py b/sync_surface_generator.py deleted file mode 100644 index 091a74492..000000000 --- a/sync_surface_generator.py +++ /dev/null @@ -1,384 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import annotations - -import inspect -import ast -import textwrap -import importlib -import yaml -from pathlib import Path -import os - -from black import format_str, FileMode -import autoflake -""" -This module allows us to generate a synchronous API surface from our asyncio surface. -""" - -class AsyncToSyncTransformer(ast.NodeTransformer): - """ - This class is used to transform async classes into sync classes. - Generated classes are abstract, and must be subclassed to be used. - This is to ensure any required customizations from - outside of this autogeneration system are always applied - """ - - def __init__(self, *, name=None, asyncio_replacements=None, text_replacements=None, drop_methods=None, pass_methods=None, error_methods=None, replace_methods=None): - """ - Args: - - name: the name of the class being processed. 
Just used in exceptions - - asyncio_replacements: asyncio functionality to replace - - text_replacements: dict of text to replace directly in the source code and docstrings - - drop_methods: list of method names to drop from the class - - pass_methods: list of method names to replace with "pass" in the class - - error_methods: list of method names to replace with "raise NotImplementedError" in the class - - replace_methods: dict of method names to replace with custom code - """ - self.name = name - self.asyncio_replacements = asyncio_replacements or {} - self.text_replacements = text_replacements or {} - self.drop_methods = drop_methods or [] - self.pass_methods = pass_methods or [] - self.error_methods = error_methods or [] - self.replace_methods = replace_methods or {} - - def update_docstring(self, docstring): - """ - Update docstring to replace any key words in the text_replacements dict - """ - if not docstring: - return docstring - for key_word, replacement in self.text_replacements.items(): - docstring = docstring.replace(f" {key_word} ", f" {replacement} ") - if "\n" in docstring: - # if multiline docstring, add linebreaks to put the """ on a separate line - docstring = "\n" + docstring + "\n\n" - return docstring - - def visit_FunctionDef(self, node): - """ - Re-use replacement logic for Async functions - """ - return self.visit_AsyncFunctionDef(node) - - def visit_AsyncFunctionDef(self, node): - """ - Replace async functions with sync functions - """ - # replace docstring - docstring = self.update_docstring(ast.get_docstring(node)) - if isinstance(node.body[0], ast.Expr) and isinstance( - node.body[0].value, ast.Str - ): - node.body[0].value.s = docstring - # drop or replace body as needed - if node.name in self.drop_methods: - return None - elif node.name in self.pass_methods: - # keep only docstring in pass mode - node.body = [ast.Expr(value=ast.Str(s=docstring))] - elif node.name in self.error_methods: - self._create_error_node(node, "Function not implemented in sync class") - elif node.name in self.replace_methods: - # replace function body with custom code - new_body = [] - for line in self.replace_methods[node.name].split("\n"): - parsed = ast.parse(line) - if len(parsed.body) > 0: - new_body.append(parsed.body[0]) - node.body = new_body - else: - # check if the function contains non-replaced usage of asyncio - func_ast = ast.parse(ast.unparse(node)) - for n in ast.walk(func_ast): - if isinstance(n, ast.Call) \ - and isinstance(n.func, ast.Attribute) \ - and isinstance(n.func.value, ast.Name) \ - and n.func.value.id == "asyncio" \ - and n.func.attr not in self.asyncio_replacements: - path_str = f"{self.name}.{node.name}" if self.name else node.name - print(f"{path_str} contains unhandled asyncio calls: {n.func.attr}. 
Add method to drop_methods, pass_methods, or error_methods to handle safely.") - # remove pytest.mark.asyncio decorator - if hasattr(node, "decorator_list"): - # TODO: make generic - new_list = [] - for decorator in node.decorator_list: - # check for @CrossSync.x() decorators - if "CrossSync" in ast.dump(decorator): - if "rename_sync" in ast.dump(decorator): - new_name = decorator.args[0].value - node.name = new_name - elif "drop_method" in ast.dump(decorator): - return None - else: - new_list.append(decorator) - node.decorator_list = new_list - is_asyncio_decorator = lambda d: all(x in ast.dump(d) for x in ["pytest", "mark", "asyncio"]) - node.decorator_list = [ - d for d in node.decorator_list if not is_asyncio_decorator(d) - ] - - # visit string type annotations - for arg in node.args.args: - if arg.annotation: - if isinstance(arg.annotation, ast.Constant): - arg.annotation.value = self.text_replacements.get(arg.annotation.value, arg.annotation.value) - return ast.copy_location( - ast.FunctionDef( - self.text_replacements.get(node.name, node.name), - self.visit(node.args), - [self.visit(stmt) for stmt in node.body], - [self.visit(stmt) for stmt in node.decorator_list], - node.returns and self.visit(node.returns), - ), - node, - ) - - def visit_Call(self, node): - return ast.copy_location( - ast.Call( - self.visit(node.func), - [self.visit(arg) for arg in node.args], - [self.visit(keyword) for keyword in node.keywords], - ), - node, - ) - - def visit_Await(self, node): - return self.visit(node.value) - - def visit_Attribute(self, node): - if ( - isinstance(node.value, ast.Name) - and isinstance(node.value.ctx, ast.Load) - and node.value.id == "asyncio" - and node.attr in self.asyncio_replacements - ): - replacement = self.asyncio_replacements[node.attr] - return ast.copy_location(ast.parse(replacement, mode="eval").body, node) - fixed = ast.copy_location( - ast.Attribute( - self.visit(node.value), - self.text_replacements.get(node.attr, node.attr), # replace attr value - node.ctx - ), node - ) - return fixed - - def visit_Name(self, node): - node.id = self.text_replacements.get(node.id, node.id) - return node - - def visit_AsyncFor(self, node): - return ast.copy_location( - ast.For( - self.visit(node.target), - self.visit(node.iter), - [self.visit(stmt) for stmt in node.body], - [self.visit(stmt) for stmt in node.orelse], - ), - node, - ) - - def visit_AsyncWith(self, node): - return ast.copy_location( - ast.With( - [self.visit(item) for item in node.items], - [self.visit(stmt) for stmt in node.body], - ), - node, - ) - - def visit_ListComp(self, node): - # replace [x async for ...] with [x for ...] 
- new_generators = [] - for generator in node.generators: - if generator.is_async: - new_generators.append( - ast.copy_location( - ast.comprehension( - self.visit(generator.target), - self.visit(generator.iter), - [self.visit(i) for i in generator.ifs], - False, - ), - generator, - ) - ) - else: - new_generators.append(generator) - node.generators = new_generators - return ast.copy_location( - ast.ListComp( - self.visit(node.elt), - [self.visit(gen) for gen in node.generators], - ), - node, - ) - - def visit_Subscript(self, node): - if ( - hasattr(node, "value") - and isinstance(node.value, ast.Name) - and self.text_replacements.get(node.value.id, False) is None - ): - # needed for Awaitable - return self.visit(node.slice) - return ast.copy_location( - ast.Subscript( - self.visit(node.value), - self.visit(node.slice), - node.ctx, - ), - node, - ) - - @staticmethod - def _create_error_node(node, error_msg): - # replace function body with NotImplementedError - exc_node = ast.Call( - func=ast.Name(id="NotImplementedError", ctx=ast.Load()), - args=[ast.Str(s=error_msg)], - keywords=[], - ) - raise_node = ast.Raise(exc=exc_node, cause=None) - node.body = [raise_node] - - def get_imports(self, filename): - """ - Extract all imports from file root - - Include if statements that contain imports - """ - with open(filename, "r") as f: - full_tree = ast.parse(f.read(), filename) - imports = [node for node in full_tree.body if isinstance(node, (ast.Import, ast.ImportFrom))] - if_imports = [self.visit(node) for node in full_tree.body if isinstance(node, ast.If) and any(isinstance(n, (ast.Import, ast.ImportFrom)) for n in node.body)] - try_imports = [self.visit(node) for node in full_tree.body if isinstance(node, ast.Try)] - return set(imports + if_imports + try_imports) - - -def transform_class(in_obj: Type, **kwargs): - filename = inspect.getfile(in_obj) - lines, lineno = inspect.getsourcelines(in_obj) - ast_tree = ast.parse(textwrap.dedent("".join(lines)), filename) - new_name = None - if ast_tree.body and isinstance(ast_tree.body[0], ast.ClassDef): - cls_node = ast_tree.body[0] - # remove cross_sync decorator - if hasattr(cls_node, "decorator_list"): - cls_node.decorator_list = [d for d in cls_node.decorator_list if not isinstance(d, ast.Call) or not isinstance(d.func, ast.Attribute) or not isinstance(d.func.value, ast.Name) or d.func.value.id != "CrossSync"] - # update name - old_name = cls_node.name - # set default name for new class if unset - new_name = kwargs.pop("autogen_sync_name", f"{old_name}_SyncGen") - cls_node.name = new_name - ast.increment_lineno(ast_tree, lineno - 1) - # add ABC as base class - # cls_node.bases = ast_tree.body[0].bases + [ - # ast.Name("ABC", ast.Load()), - # ] - # remove top-level imports if any. 
Add them back later - ast_tree.body = [n for n in ast_tree.body if not isinstance(n, (ast.Import, ast.ImportFrom))] - # transform - transformer = AsyncToSyncTransformer(name=new_name, **kwargs) - transformer.visit(ast_tree) - # find imports - imports = transformer.get_imports(filename) - return ast_tree.body, imports - - -if __name__ == "__main__": - # for load_path in ["./google/cloud/bigtable/data/_sync/sync_gen.yaml", "./google/cloud/bigtable/data/_sync/unit_tests.yaml"]: - # for load_path in ["./google/cloud/bigtable/data/_sync/sync_gen.yaml"]: - # config = yaml.safe_load(Path(load_path).read_text()) - - # save_path = config.get("save_path") - # code = transform_from_config(config) - - # if save_path is not None: - # with open(save_path, "w") as f: - # f.write(code) - # find all classes in the library - lib_root = "google/cloud/bigtable/data/_async" - lib_files = [f"{lib_root}/{f}" for f in os.listdir(lib_root) if f.endswith(".py")] - - test_root = "tests/unit/data/_async" - test_files = [f"{test_root}/{f}" for f in os.listdir(test_root) if f.endswith(".py")] - all_files = lib_files + test_files - - enabled_classes = [] - for file in all_files: - file_module = file.replace("/", ".")[:-3] - for cls_name, cls in inspect.getmembers(importlib.import_module(file_module), inspect.isclass): - # keep only those with CrossSync annotation - if hasattr(cls, "cross_sync_enabled") and not cls in enabled_classes: - enabled_classes.append(cls) - # bucket classes by output location - all_paths = {c.cross_sync_file_path for c in enabled_classes} - class_map = {loc: [c for c in enabled_classes if c.cross_sync_file_path == loc] for loc in all_paths} - # generate sync code for each class - for output_file in class_map.keys(): - # initialize new tree and import list - file_mypy_ignore = set() - combined_tree = ast.parse("") - combined_imports = set() - for async_class in class_map[output_file]: - text_replacements = {"CrossSync": "CrossSync._Sync_Impl", **async_class.cross_sync_replace_symbols} - file_mypy_ignore.update(async_class.cross_sync_mypy_ignore) - tree_body, imports = transform_class(async_class, autogen_sync_name=async_class.cross_sync_class_name, text_replacements=text_replacements) - # update combined data - combined_tree.body.extend(tree_body) - combined_imports.update(imports) - # render tree as string of code - import_unique = list(set([ast.unparse(i) for i in combined_imports])) - import_unique.sort() - google, non_google = [], [] - for i in import_unique: - if "google" in i: - google.append(i) - else: - non_google.append(i) - import_str = "\n".join(non_google + [""] + google) - mypy_ignore_str = ", ".join(file_mypy_ignore) - if mypy_ignore_str: - mypy_ignore_str = f"# mypy: disable-error-code=\"{mypy_ignore_str}\"" - # append clean tree - header = """# Copyright 2024 Google LLC - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - # - # This file is automatically generated by sync_surface_generator.py. Do not edit. 
- """ - full_code = f"{header}{mypy_ignore_str}\n\n{import_str}\n\n{ast.unparse(combined_tree)}" - full_code = autoflake.fix_code(full_code, remove_all_unused_imports=True) - formatted_code = format_str(full_code, mode=FileMode()) - print(f"saving {[c.cross_sync_class_name for c in class_map[output_file]]} to {output_file}...") - with open(output_file, "w") as f: - f.write(formatted_code) - - From 73c6e2fde1a077da1848b7398b7c19c07ceb0af8 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 15:02:19 -0700 Subject: [PATCH 150/360] remove sync pooled generator --- .../bigtable_v2/services/bigtable/client.py | 2 - .../bigtable/transports/pooled_grpc.py | 445 ------------------ 2 files changed, 447 deletions(-) delete mode 100644 google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py diff --git a/google/cloud/bigtable_v2/services/bigtable/client.py b/google/cloud/bigtable_v2/services/bigtable/client.py index 4a380651d..7eda705b9 100644 --- a/google/cloud/bigtable_v2/services/bigtable/client.py +++ b/google/cloud/bigtable_v2/services/bigtable/client.py @@ -56,7 +56,6 @@ from .transports.grpc import BigtableGrpcTransport from .transports.grpc_asyncio import BigtableGrpcAsyncIOTransport from .transports.pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport -from .transports.pooled_grpc import PooledBigtableGrpcTransport from .transports.rest import BigtableRestTransport @@ -72,7 +71,6 @@ class BigtableClientMeta(type): _transport_registry["grpc"] = BigtableGrpcTransport _transport_registry["grpc_asyncio"] = BigtableGrpcAsyncIOTransport _transport_registry["pooled_grpc_asyncio"] = PooledBigtableGrpcAsyncIOTransport - _transport_registry["pooled_grpc"] = PooledBigtableGrpcTransport _transport_registry["rest"] = BigtableRestTransport def get_transport_class( diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py deleted file mode 100644 index 2c808a000..000000000 --- a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py +++ /dev/null @@ -1,445 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import warnings -from functools import partialmethod -from functools import partial -import time -from typing import ( - Awaitable, - Callable, - Dict, - Optional, - Sequence, - Tuple, - Union, - List, - Type, -) - -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore - -from google.cloud.bigtable_v2.types import bigtable -from .base import BigtableTransport, DEFAULT_CLIENT_INFO -from .grpc import BigtableGrpcTransport - - -class PooledMultiCallable: - def __init__(self, channel_pool: "PooledChannel", *args, **kwargs): - self._init_args = args - self._init_kwargs = kwargs - self.next_channel_fn = channel_pool.next_channel - - def with_call(self, *args, **kwargs): - raise NotImplementedError() - - def future(self, *args, **kwargs): - raise NotImplementedError() - - -class PooledUnaryUnaryMultiCallable(PooledMultiCallable, grpc.UnaryUnaryMultiCallable): - def __call__(self, *args, **kwargs): - return self.next_channel_fn().unary_unary( - *self._init_args, **self._init_kwargs - )(*args, **kwargs) - - -class PooledUnaryStreamMultiCallable( - PooledMultiCallable, grpc.UnaryStreamMultiCallable -): - def __call__(self, *args, **kwargs): - return self.next_channel_fn().unary_stream( - *self._init_args, **self._init_kwargs - )(*args, **kwargs) - - -class PooledStreamUnaryMultiCallable( - PooledMultiCallable, grpc.StreamUnaryMultiCallable -): - def __call__(self, *args, **kwargs): - return self.next_channel_fn().stream_unary( - *self._init_args, **self._init_kwargs - )(*args, **kwargs) - - -class PooledStreamStreamMultiCallable( - PooledMultiCallable, grpc.StreamStreamMultiCallable -): - def __call__(self, *args, **kwargs): - return self.next_channel_fn().stream_stream( - *self._init_args, **self._init_kwargs - )(*args, **kwargs) - - -class PooledChannel(grpc.Channel): - def __init__( - self, - pool_size: int = 3, - host: str = "bigtable.googleapis.com", - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - quota_project_id: Optional[str] = None, - default_scopes: Optional[Sequence[str]] = None, - scopes: Optional[Sequence[str]] = None, - default_host: Optional[str] = None, - insecure: bool = False, - **kwargs, - ): - self._pool: List[grpc.Channel] = [] - self._next_idx = 0 - if insecure: - self._create_channel = partial(grpc.insecure_channel, host) - else: - self._create_channel = partial( - grpc_helpers.create_channel, - target=host, - credentials=credentials, - credentials_file=credentials_file, - quota_project_id=quota_project_id, - default_scopes=default_scopes, - scopes=scopes, - default_host=default_host, - **kwargs, - ) - for i in range(pool_size): - self._pool.append(self._create_channel()) - - def next_channel(self) -> grpc.Channel: - channel = self._pool[self._next_idx] - self._next_idx = (self._next_idx + 1) % len(self._pool) - return channel - - def unary_unary(self, *args, **kwargs) -> grpc.UnaryUnaryMultiCallable: - return PooledUnaryUnaryMultiCallable(self, *args, **kwargs) - - def unary_stream(self, *args, **kwargs) -> grpc.UnaryStreamMultiCallable: - return PooledUnaryStreamMultiCallable(self, *args, **kwargs) - - def stream_unary(self, *args, **kwargs) -> grpc.StreamUnaryMultiCallable: - return PooledStreamUnaryMultiCallable(self, *args, **kwargs) - - def stream_stream(self, *args, **kwargs) -> grpc.StreamStreamMultiCallable: - return 
PooledStreamStreamMultiCallable(self, *args, **kwargs) - - def close(self): - for channel in self._pool: - channel.close() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: - raise NotImplementedError() - - def wait_for_state_change(self, last_observed_state): - raise NotImplementedError() - - def subscribe( - self, callback, try_to_connect: bool = False - ) -> grpc.ChannelConnectivity: - raise NotImplementedError() - - def unsubscribe(self, callback): - raise NotImplementedError() - - def replace_channel( - self, channel_idx, grace=1, new_channel=None, event=None - ) -> grpc.Channel: - """ - Replaces a channel in the pool with a fresh one. - - The `new_channel` will start processing new requests immidiately, - but the old channel will continue serving existing clients for - `grace` seconds - - Args: - channel_idx(int): the channel index in the pool to replace - grace(Optional[float]): The time to wait for active RPCs to - finish. If a grace period is not specified (by passing None for - grace), all existing RPCs are cancelled immediately. - new_channel(grpc.Channel): a new channel to insert into the pool - at `channel_idx`. If `None`, a new channel will be created. - event(Optional[threading.Event]): an event to signal when the - replacement should be aborted. If set, will call `event.wait()` - instead of the `time.sleep` function. - """ - if channel_idx >= len(self._pool) or channel_idx < 0: - raise ValueError( - f"invalid channel_idx {channel_idx} for pool size {len(self._pool)}" - ) - if new_channel is None: - new_channel = self._create_channel() - old_channel = self._pool[channel_idx] - self._pool[channel_idx] = new_channel - if event: - event.wait(grace) - else: - time.sleep(grace) - old_channel.close() - return new_channel - - -class PooledBigtableGrpcTransport(BigtableGrpcTransport): - """Pooled gRPC backend transport for Bigtable. - - Service for reading from and writing to existing Bigtable - tables. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. - - This class allows channel pooling, so multiple channels can be used concurrently - when making requests. Channels are rotated in a round-robin fashion. - """ - - @classmethod - def with_fixed_size(cls, pool_size) -> Type["PooledBigtableGrpcTransport"]: - """ - Creates a new class with a fixed channel pool size. - - A fixed channel pool makes compatibility with other transports easier, - as the initializer signature is the same. - """ - - class PooledTransportFixed(cls): - __init__ = partialmethod(cls.__init__, pool_size=pool_size) - - PooledTransportFixed.__name__ = f"{cls.__name__}_{pool_size}" - PooledTransportFixed.__qualname__ = PooledTransportFixed.__name__ - return PooledTransportFixed - - @classmethod - def create_channel( - cls, - pool_size: int = 3, - host: str = "bigtable.googleapis.com", - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> grpc.Channel: - """Create and return a PooledChannel object, representing a pool of gRPC channels - Args: - pool_size (int): The number of channels in the pool. 
-            host (Optional[str]): The host for the channel to use.
-            credentials (Optional[~.Credentials]): The
-                authorization credentials to attach to requests. These
-                credentials identify this application to the service. If
-                none are specified, the client will attempt to ascertain
-                the credentials from the environment.
-            credentials_file (Optional[str]): A file with credentials that can
-                be loaded with :func:`google.auth.load_credentials_from_file`.
-                This argument is ignored if ``channel`` is provided.
-            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
-                service. These are only used when credentials are not specified and
-                are passed to :func:`google.auth.default`.
-            quota_project_id (Optional[str]): An optional project to use for billing
-                and quota.
-            kwargs (Optional[dict]): Keyword arguments, which are passed to the
-                channel creation.
-        Returns:
-            PooledChannel: a channel pool object
-        """
-
-        return PooledChannel(
-            pool_size,
-            host,
-            credentials=credentials,
-            credentials_file=credentials_file,
-            quota_project_id=quota_project_id,
-            default_scopes=cls.AUTH_SCOPES,
-            scopes=scopes,
-            default_host=cls.DEFAULT_HOST,
-            **kwargs,
-        )
-
-    def __init__(
-        self,
-        *,
-        pool_size: int = 3,
-        host: str = "bigtable.googleapis.com",
-        credentials: Optional[ga_credentials.Credentials] = None,
-        credentials_file: Optional[str] = None,
-        scopes: Optional[Sequence[str]] = None,
-        api_mtls_endpoint: Optional[str] = None,
-        client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
-        ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None,
-        client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
-        quota_project_id: Optional[str] = None,
-        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
-        always_use_jwt_access: Optional[bool] = False,
-        api_audience: Optional[str] = None,
-    ) -> None:
-        """Instantiate the transport.
-
-        Args:
-            pool_size (int): the number of grpc channels to maintain in a pool
-            host (Optional[str]):
-                The hostname to connect to.
-            credentials (Optional[google.auth.credentials.Credentials]): The
-                authorization credentials to attach to requests. These
-                credentials identify the application to the service; if none
-                are specified, the client will attempt to ascertain the
-                credentials from the environment.
-                This argument is ignored if ``channel`` is provided.
-            credentials_file (Optional[str]): A file with credentials that can
-                be loaded with :func:`google.auth.load_credentials_from_file`.
-                This argument is ignored if ``channel`` is provided.
-            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
-                service. These are only used when credentials are not specified and
-                are passed to :func:`google.auth.default`.
-            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
-                If provided, it overrides the ``host`` argument and tries to create
-                a mutual TLS channel with client SSL credentials from
-                ``client_cert_source`` or application default SSL credentials.
-            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
-                Deprecated. A callback to provide client SSL certificate bytes and
-                private key bytes, both in PEM format. It is ignored if
-                ``api_mtls_endpoint`` is None.
-            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
-                for the grpc channel. It is ignored if ``channel`` is provided.
-            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
-                A callback to provide client certificate bytes and private key bytes,
-                both in PEM format. It is used to configure a mutual TLS channel. It is
-                ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
-            quota_project_id (Optional[str]): An optional project to use for billing
-                and quota.
-            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
-                The client info used to send a user-agent string along with
-                API requests. If ``None``, then default info will be used.
-                Generally, you only need to set this if you're developing
-                your own client library.
-            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
-                be used for service account credentials.
-
-        Raises:
-          google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
-              creation failed for any reason.
-          google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
-              and ``credentials_file`` are passed.
-            ValueError: if ``pool_size`` <= 0
-        """
-        if pool_size <= 0:
-            raise ValueError(f"invalid pool_size: {pool_size}")
-        self._ssl_channel_credentials = ssl_channel_credentials
-        self._stubs: Dict[str, Callable] = {}
-
-        if api_mtls_endpoint:
-            warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
-        if client_cert_source:
-            warnings.warn("client_cert_source is deprecated", DeprecationWarning)
-
-        if api_mtls_endpoint:
-            host = api_mtls_endpoint
-
-            # Create SSL credentials with client_cert_source or application
-            # default SSL credentials.
-            if client_cert_source:
-                cert, key = client_cert_source()
-                self._ssl_channel_credentials = grpc.ssl_channel_credentials(
-                    certificate_chain=cert, private_key=key
-                )
-            else:
-                self._ssl_channel_credentials = SslCredentials().ssl_credentials
-
-        else:
-            if client_cert_source_for_mtls and not ssl_channel_credentials:
-                cert, key = client_cert_source_for_mtls()
-                self._ssl_channel_credentials = grpc.ssl_channel_credentials(
-                    certificate_chain=cert, private_key=key
-                )
-
-        # The base transport sets the host, credentials and scopes
-        BigtableTransport.__init__(
-            self,
-            host=host,
-            credentials=credentials,
-            credentials_file=credentials_file,
-            scopes=scopes,
-            quota_project_id=quota_project_id,
-            client_info=client_info,
-            always_use_jwt_access=always_use_jwt_access,
-            api_audience=api_audience,
-        )
-        self._quota_project_id = quota_project_id
-        self._grpc_channel = type(self).create_channel(
-            pool_size,
-            self._host,
-            # use the credentials which are saved
-            credentials=self._credentials,
-            # Set ``credentials_file`` to ``None`` here as
-            # the credentials that we saved earlier should be used.
-            credentials_file=None,
-            scopes=self._scopes,
-            ssl_credentials=self._ssl_channel_credentials,
-            quota_project_id=self._quota_project_id,
-            options=[
-                ("grpc.max_send_message_length", -1),
-                ("grpc.max_receive_message_length", -1),
-            ],
-        )
-
-        # Wrap messages. This must be done after self._grpc_channel exists
-        self._prep_wrapped_messages(client_info)
-
-    @property
-    def pool_size(self) -> int:
-        """The number of grpc channels in the pool."""
-        return len(self._grpc_channel._pool)
-
-    @property
-    def channels(self) -> List[grpc.Channel]:
-        """Access the internal list of grpc channels."""
-        return self._grpc_channel._pool
-
-    def replace_channel(
-        self, channel_idx, grace=1, new_channel=None, event=None
-    ) -> grpc.Channel:
-        """
-        Replaces a channel in the pool with a fresh one.
- - The `new_channel` will start processing new requests immidiately, - but the old channel will continue serving existing clients for `grace` seconds - - Args: - channel_idx(int): the channel index in the pool to replace - grace(Optional[float]): The time to wait for active RPCs to - finish. If a grace period is not specified (by passing None for - grace), all existing RPCs are cancelled immediately. - new_channel(grpc.Channel): a new channel to insert into the pool - at `channel_idx`. If `None`, a new channel will be created. - event(Optional[threading.Event]): an event to signal when the - replacement should be aborted. If set, will call `event.wait()` - instead of the `time.sleep` function. - """ - return self._grpc_channel.replace_channel( - channel_idx=channel_idx, grace=grace, new_channel=new_channel, event=event - ) - - -__all__ = ("PooledBigtableGrpcTransport",) From 19112961c7e2d6c1079bd99530b418699cba73b7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 15:26:31 -0700 Subject: [PATCH 151/360] reverted some style changes --- .../bigtable/data/_async/_mutate_rows.py | 19 ++--- .../cloud/bigtable/data/_async/_read_rows.py | 10 +-- google/cloud/bigtable/data/_async/client.py | 70 +++++++++++-------- .../cloud/bigtable/data/_sync/cross_sync.py | 41 +++++++---- tests/system/data/test_system_async.py | 2 +- tests/unit/data/_async/test__read_rows.py | 3 +- .../data/_async/test_read_rows_acceptance.py | 1 + 7 files changed, 88 insertions(+), 58 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 2e7181695..968b07c2a 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -21,13 +21,10 @@ from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries import google.cloud.bigtable_v2.types.bigtable as types_pb +import google.cloud.bigtable.data.exceptions as bt_exceptions from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _retry_exception_factory -from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete -from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup -from google.cloud.bigtable.data.exceptions import RetryExceptionGroup -from google.cloud.bigtable.data.exceptions import FailedMutationEntryError # mutate_rows requests are limited to this number of mutations from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT @@ -112,7 +109,7 @@ def __init__( # RPC level errors *retryable_exceptions, # Entry level errors - _MutateRowsIncomplete, + bt_exceptions._MutateRowsIncomplete, ) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) # Note: _operation could be a raw coroutine, but using a lambda @@ -158,11 +155,15 @@ async def start(self): elif len(exc_list) == 1: cause_exc = exc_list[0] else: - cause_exc = RetryExceptionGroup(exc_list) + cause_exc = bt_exceptions.RetryExceptionGroup(exc_list) entry = self.mutations[idx].entry - all_errors.append(FailedMutationEntryError(idx, entry, cause_exc)) + all_errors.append( + bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc) + ) if all_errors: - raise MutationsExceptionGroup(all_errors, len(self.mutations)) + raise bt_exceptions.MutationsExceptionGroup( + all_errors, len(self.mutations) + ) async def _run_attempt(self): """ @@ 
-216,7 +217,7 @@ async def _run_attempt(self): # check if attempt succeeded, or needs to be retried if self.remaining_indices: # unfinished work; raise exception to trigger retry - raise _MutateRowsIncomplete + raise bt_exceptions._MutateRowsIncomplete def _handle_entry_error(self, idx: int, exc: Exception): """ diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 9285d5f6f..adc18aa90 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -29,7 +29,9 @@ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _RowSetComplete -from google.cloud.bigtable.data import _helpers +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data._helpers import _make_metadata +from google.cloud.bigtable.data._helpers import _retry_exception_factory from google.api_core import retry as retries from google.api_core.retry import exponential_sleep_generator @@ -90,7 +92,7 @@ def __init__( attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), ): - self.attempt_timeout_gen = _helpers._attempt_timeout_generator( + self.attempt_timeout_gen = _attempt_timeout_generator( attempt_timeout, operation_timeout ) self.operation_timeout = operation_timeout @@ -104,7 +106,7 @@ def __init__( self.request = query._to_pb(table) self.table = table self._predicate = retries.if_exception_type(*retryable_exceptions) - self._metadata = _helpers._make_metadata( + self._metadata = _make_metadata( table.table_name, table.app_profile_id, ) @@ -123,7 +125,7 @@ def start_operation(self) -> CrossSync.Iterable[Row]: self._predicate, exponential_sleep_generator(0.01, 60, multiplier=2), self.operation_timeout, - exception_factory=_helpers._retry_exception_factory, + exception_factory=_retry_exception_factory, ) def _read_rows_attempt(self) -> CrossSync.Iterable[Row]: diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 1c52f83c2..1815caef5 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -56,8 +56,15 @@ from google.cloud.bigtable.data.exceptions import FailedQueryShardError from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup -from google.cloud.bigtable.data import _helpers from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _WarmedInstanceKey +from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT +from google.cloud.bigtable.data._helpers import _make_metadata +from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data._helpers import _validate_timeouts +from google.cloud.bigtable.data._helpers import _get_retryable_errors +from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _MB_SIZE from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry @@ -145,7 +152,6 @@ def __init__( ValueError: if pool_size is less than 1 """ # set up transport in registry - # TODO: simplify when released: https://github.com/googleapis/gapic-generator-python/pull/1699 transport_str = f"bt-{self._client_version()}-{pool_size}" transport = 
PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) BigtableClientMeta._transport_registry[transport_str] = transport @@ -183,10 +189,10 @@ def __init__( PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport ) # keep track of active instances to for warmup on channel refresh - self._active_instances: Set[_helpers._WarmedInstanceKey] = set() + self._active_instances: Set[_WarmedInstanceKey] = set() # keep track of table objects associated with each instance # only remove instance from _active_instances when all associated tables remove it - self._instance_owners: dict[_helpers._WarmedInstanceKey, Set[int]] = {} + self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() self._channel_refresh_tasks: list[CrossSync.Task[None]] = [] self._executor = ( @@ -253,7 +259,9 @@ def _start_background_channel_refresh(self) -> None: ) self._channel_refresh_tasks.append(refresh_task) refresh_task.add_done_callback( - lambda _: self._channel_refresh_tasks.remove(refresh_task) if refresh_task in self._channel_refresh_tasks else None + lambda _: self._channel_refresh_tasks.remove(refresh_task) + if refresh_task in self._channel_refresh_tasks + else None ) async def close(self, timeout: float | None = None): @@ -269,7 +277,7 @@ async def close(self, timeout: float | None = None): await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) async def _ping_and_warm_instances( - self, channel: Channel, instance_key: _helpers._WarmedInstanceKey | None = None + self, channel: Channel, instance_key: _WarmedInstanceKey | None = None ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel @@ -380,7 +388,7 @@ async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: owners call _remove_instance_registration """ instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _helpers._WarmedInstanceKey( + instance_key = _WarmedInstanceKey( instance_name, owner.table_name, owner.app_profile_id ) self._instance_owners.setdefault(instance_key, set()).add(id(owner)) @@ -413,7 +421,7 @@ async def _remove_instance_registration( bool: True if instance was removed, else False """ instance_name = self._gapic_client.instance_path(self.project, instance_id) - instance_key = _helpers._WarmedInstanceKey( + instance_key = _WarmedInstanceKey( instance_name, owner.table_name, owner.app_profile_id ) owner_list = self._instance_owners.get(instance_key, set()) @@ -550,15 +558,15 @@ def __init__( # NOTE: any changes to the signature of this method should also be reflected # in client.get_table() # validate timeouts - _helpers._validate_timeouts( + _validate_timeouts( default_operation_timeout, default_attempt_timeout, allow_none=True ) - _helpers._validate_timeouts( + _validate_timeouts( default_read_rows_operation_timeout, default_read_rows_attempt_timeout, allow_none=True, ) - _helpers._validate_timeouts( + _validate_timeouts( default_mutate_rows_operation_timeout, default_mutate_rows_attempt_timeout, allow_none=True, @@ -639,10 +647,10 @@ async def read_rows_stream( from any retries that failed google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error """ - operation_timeout, attempt_timeout = _helpers._get_timeouts( + operation_timeout, attempt_timeout = _get_timeouts( operation_timeout, attempt_timeout, self ) - retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) + retryable_excs = 
_get_retryable_errors(retryable_errors, self) row_merger = _ReadRowsOperationAsync( query, @@ -790,16 +798,16 @@ async def read_rows_sharded( """ if not sharded_query: raise ValueError("empty sharded_query") - operation_timeout, attempt_timeout = _helpers._get_timeouts( + operation_timeout, attempt_timeout = _get_timeouts( operation_timeout, attempt_timeout, self ) # make sure each rpc stays within overall operation timeout - rpc_timeout_generator = _helpers._attempt_timeout_generator( + rpc_timeout_generator = _attempt_timeout_generator( operation_timeout, operation_timeout ) # limit the number of concurrent requests using a semaphore - concurrency_sem = CrossSync.Semaphore(_helpers._CONCURRENCY_LIMIT) + concurrency_sem = CrossSync.Semaphore(_CONCURRENCY_LIMIT) async def read_rows_with_semaphore(query): async with concurrency_sem: @@ -935,20 +943,20 @@ async def sample_row_keys( google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error """ # prepare timeouts - operation_timeout, attempt_timeout = _helpers._get_timeouts( + operation_timeout, attempt_timeout = _get_timeouts( operation_timeout, attempt_timeout, self ) - attempt_timeout_gen = _helpers._attempt_timeout_generator( + attempt_timeout_gen = _attempt_timeout_generator( attempt_timeout, operation_timeout ) # prepare retryable - retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) + retryable_excs = _get_retryable_errors(retryable_errors, self) predicate = retries.if_exception_type(*retryable_excs) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) # prepare request - metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) + metadata = _make_metadata(self.table_name, self.app_profile_id) async def execute_rpc(): results = await self.client._gapic_client.sample_row_keys( @@ -965,7 +973,7 @@ async def execute_rpc(): predicate, sleep_generator, operation_timeout, - exception_factory=_helpers._retry_exception_factory, + exception_factory=_retry_exception_factory, ) def mutations_batcher( @@ -1058,7 +1066,7 @@ async def mutate_row( safely retried. 
ValueError: if invalid arguments are provided """ - operation_timeout, attempt_timeout = _helpers._get_timeouts( + operation_timeout, attempt_timeout = _get_timeouts( operation_timeout, attempt_timeout, self ) @@ -1069,7 +1077,7 @@ async def mutate_row( if all(mutation.is_idempotent() for mutation in mutations_list): # mutations are all idempotent and safe to retry predicate = retries.if_exception_type( - *_helpers._get_retryable_errors(retryable_errors, self) + *_get_retryable_errors(retryable_errors, self) ) else: # mutations should not be retried @@ -1084,7 +1092,7 @@ async def mutate_row( table_name=self.table_name, app_profile_id=self.app_profile_id, timeout=attempt_timeout, - metadata=_helpers._make_metadata(self.table_name, self.app_profile_id), + metadata=_make_metadata(self.table_name, self.app_profile_id), retry=None, ) return await CrossSync.retry_target( @@ -1092,7 +1100,7 @@ async def mutate_row( predicate, sleep_generator, operation_timeout, - exception_factory=_helpers._retry_exception_factory, + exception_factory=_retry_exception_factory, ) async def bulk_mutate_rows( @@ -1135,10 +1143,10 @@ async def bulk_mutate_rows( Contains details about any failed entries in .exceptions ValueError: if invalid arguments are provided """ - operation_timeout, attempt_timeout = _helpers._get_timeouts( + operation_timeout, attempt_timeout = _get_timeouts( operation_timeout, attempt_timeout, self ) - retryable_excs = _helpers._get_retryable_errors(retryable_errors, self) + retryable_excs = _get_retryable_errors(retryable_errors, self) operation = _MutateRowsOperationAsync( self.client._gapic_client, @@ -1191,7 +1199,7 @@ async def check_and_mutate_row( Raises: google.api_core.exceptions.GoogleAPIError: exceptions from grpc call """ - operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) + operation_timeout, _ = _get_timeouts(operation_timeout, None, self) if true_case_mutations is not None and not isinstance( true_case_mutations, list ): @@ -1202,7 +1210,7 @@ async def check_and_mutate_row( ): false_case_mutations = [false_case_mutations] false_case_list = [m._to_pb() for m in false_case_mutations or []] - metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) + metadata = _make_metadata(self.table_name, self.app_profile_id) result = await self.client._gapic_client.check_and_mutate_row( true_mutations=true_case_list, false_mutations=false_case_list, @@ -1246,14 +1254,14 @@ async def read_modify_write_row( google.api_core.exceptions.GoogleAPIError: exceptions from grpc call ValueError: if invalid arguments are provided """ - operation_timeout, _ = _helpers._get_timeouts(operation_timeout, None, self) + operation_timeout, _ = _get_timeouts(operation_timeout, None, self) if operation_timeout <= 0: raise ValueError("operation_timeout must be greater than 0") if rules is not None and not isinstance(rules, list): rules = [rules] if not rules: raise ValueError("rules must contain at least one item") - metadata = _helpers._make_metadata(self.table_name, self.app_profile_id) + metadata = _make_metadata(self.table_name, self.app_profile_id) result = await self.client._gapic_client.read_modify_write_row( rules=[rule._to_pb() for rule in rules], row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 4ce65c38e..9d72dedc9 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py 
@@ -41,19 +41,25 @@ T = TypeVar("T") + def pytest_mark_asyncio(func): try: import pytest + return pytest.mark.asyncio(func) except ImportError: return func + def pytest_asyncio_fixture(*args, **kwargs): import pytest_asyncio + def decorator(func): return pytest_asyncio.fixture(*args, **kwargs)(func) + return decorator + class AstDecorator: """ Helper class for CrossSync decorators used for guiding ast transformations. @@ -62,7 +68,13 @@ class AstDecorator: but act as no-ops when encountered in live code """ - def __init__(self, decorator_name, required_keywords=(), inner_decorator=None, **default_kwargs): + def __init__( + self, + decorator_name, + required_keywords=(), + inner_decorator=None, + **default_kwargs, + ): self.name = decorator_name self.required_kwargs = required_keywords self.default_kwargs = default_kwargs @@ -77,15 +89,18 @@ def __call__(self, *args, **kwargs): return self.inner_decorator(*args, **kwargs) if len(args) == 1 and callable(args[0]): return args[0] + def decorator(func): return func + return decorator def parse_ast_keywords(self, node): - got_kwargs = { - kw.arg: self._convert_ast_to_py(kw.value) - for kw in node.keywords - } if hasattr(node, "keywords") else {} + got_kwargs = ( + {kw.arg: self._convert_ast_to_py(kw.value) for kw in node.keywords} + if hasattr(node, "keywords") + else {} + ) for key in got_kwargs.keys(): if key not in self.all_valid_keys: raise ValueError(f"Invalid keyword argument: {key}") @@ -99,6 +114,7 @@ def _convert_ast_to_py(self, ast_node): Helper to convert ast primitives to python primitives. Used when unwrapping kwargs """ import ast + if isinstance(ast_node, ast.Constant): return ast_node.value if isinstance(ast_node, ast.List): @@ -112,12 +128,9 @@ def _convert_ast_to_py(self, ast_node): def _node_eq(self, node: ast.Node): import ast + if "CrossSync" in ast.dump(node): - decorator_type = ( - node.func.attr - if hasattr(node, "func") - else node.attr - ) + decorator_type = node.func.attr if hasattr(node, "func") else node.attr if decorator_type == self.name: return True return False @@ -138,6 +151,7 @@ def __getattr__(self, name): return decorator raise AttributeError(f"CrossSync has no attribute {name}") + class CrossSync(metaclass=_DecoratorMeta): is_async = True @@ -159,8 +173,11 @@ class CrossSync(metaclass=_DecoratorMeta): Generator: TypeAlias = AsyncGenerator _decorators: list[AstDecorator] = [ - AstDecorator("pytest", inner_decorator=pytest_mark_asyncio), # decorate test methods to run with pytest-asyncio - AstDecorator("pytest_fixture", # decorate test methods to run with pytest fixture + AstDecorator( + "pytest", inner_decorator=pytest_mark_asyncio + ), # decorate test methods to run with pytest-asyncio + AstDecorator( + "pytest_fixture", # decorate test methods to run with pytest fixture inner_decorator=pytest_asyncio_fixture, scope="function", params=None, diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index ed3435e39..65b195526 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -34,6 +34,7 @@ TEST_FAMILY = "test-family" TEST_FAMILY_2 = "test-family-2" + class TempRowBuilderAsync: """ Used to add rows to a table for testing purposes. 
@@ -79,7 +80,6 @@ async def delete_rows(self): class TestSystemAsync: - @CrossSync.pytest_fixture(scope="session") async def client(self): project = os.getenv("GOOGLE_CLOUD_PROJECT") or None diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index e04436ee1..5d9957e1f 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -63,8 +63,9 @@ def test_ctor(self): expected_operation_timeout = 42 expected_request_timeout = 44 time_gen_mock = mock.Mock() + subpath = "_async" if CrossSync.is_async else "_sync" with mock.patch( - "google.cloud.bigtable.data._helpers._attempt_timeout_generator", + f"google.cloud.bigtable.data.{subpath}._read_rows._attempt_timeout_generator", time_gen_mock, ): instance = self._make_one( diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 28d8c56f9..0b844ec3c 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -39,6 +39,7 @@ ) from google.cloud.bigtable.data._sync.client import BigtableDataClient # noqa: F401 + class TestReadRowsAcceptanceAsync: @staticmethod def _get_operation_class(): From e166bbe8a5dd613ccb79b25feb32b6ab842eda2d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 15:26:43 -0700 Subject: [PATCH 152/360] removed left ofer crosssync.drop --- tests/system/data/test_system_async.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 65b195526..4d93fce78 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -94,7 +94,6 @@ async def table(self, client, table_id, instance_id): ) as table: yield table - @CrossSync.drop_method @pytest.fixture(scope="session") def event_loop(self): loop = asyncio.get_event_loop() From 6b244c55b6e50d54e6b98d37ed84bf0d94800da4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 15:31:13 -0700 Subject: [PATCH 153/360] removed else branches of cross sync imports --- .../bigtable/data/_async/_mutate_rows.py | 8 -------- .../cloud/bigtable/data/_async/_read_rows.py | 4 ---- google/cloud/bigtable/data/_async/client.py | 20 ------------------- .../bigtable/data/_async/mutations_batcher.py | 6 ------ google/cloud/bigtable/data/_helpers.py | 5 ++--- .../cloud/bigtable/data/_sync/cross_sync.py | 2 +- 6 files changed, 3 insertions(+), 42 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 968b07c2a..3de8a849e 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -31,9 +31,6 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -if not CrossSync.is_async: - from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto - if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -42,11 +39,6 @@ from google.cloud.bigtable_v2.services.bigtable.async_client import ( BigtableAsyncClient, ) - else: - from google.cloud.bigtable.data._sync.client import Table # noqa: F401 - from google.cloud.bigtable_v2.services.bigtable.client import ( # noqa: F401 - BigtableClient, - ) @dataclass diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index adc18aa90..0a4ebbf55 100644 --- 
a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -38,14 +38,10 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -if not CrossSync.is_async: - from google.cloud.bigtable.data._async._read_rows import _ResetRow if TYPE_CHECKING: if CrossSync.is_async: from google.cloud.bigtable.data._async.client import TableAsync - else: - from google.cloud.bigtable.data._sync.client import Table # noqa: F401 class _ResetRow(Exception): # noqa: F811 diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 1815caef5..14a505ce7 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -91,26 +91,6 @@ from google.cloud.bigtable_v2.services.bigtable.async_client import ( BigtableAsyncClient, ) -else: - from google.cloud.bigtable.data._sync._mutate_rows import ( # noqa: F401 - _MutateRowsOperation, - ) - from google.cloud.bigtable.data._sync.mutations_batcher import ( # noqa: F401 - MutationsBatcher, - ) - from google.cloud.bigtable.data._sync._read_rows import ( # noqa: F401 - _ReadRowsOperation, - ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( # noqa: F401 - PooledBigtableGrpcTransport, - ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( # noqa: F401 - PooledChannel, - ) - from google.cloud.bigtable_v2.services.bigtable.client import ( # noqa: F401 - BigtableClient, - ) - from typing import Iterable # noqa: F401 if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 07eac0e26..caa35425c 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -36,10 +36,6 @@ if CrossSync.is_async: from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync -else: - from google.cloud.bigtable.data._sync._mutate_rows import ( # noqa: F401 - _MutateRowsOperation, - ) if TYPE_CHECKING: @@ -47,8 +43,6 @@ if CrossSync.is_async: from google.cloud.bigtable.data._async.client import TableAsync - else: - from google.cloud.bigtable.data._sync.client import Table # noqa: F401 class _FlowControlAsync: diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index 4c9247c77..7e7038734 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -29,7 +29,6 @@ if TYPE_CHECKING: import grpc from google.cloud.bigtable.data import TableAsync - from google.cloud.bigtable.data._sync.client import Table """ Helper functions used in various places in the library. @@ -138,7 +137,7 @@ def _retry_exception_factory( def _get_timeouts( operation: float | TABLE_DEFAULT, attempt: float | None | TABLE_DEFAULT, - table: "TableAsync" | "Table", + table: "TableAsync" ) -> tuple[float, float]: """ Convert passed in timeout values to floats, using table defaults if necessary. @@ -209,7 +208,7 @@ def _validate_timeouts( def _get_retryable_errors( call_codes: Sequence["grpc.StatusCode" | int | type[Exception]] | TABLE_DEFAULT, - table: "TableAsync" | "Table", + table: "TableAsync", ) -> list[type[Exception]]: """ Convert passed in retryable error codes to a list of exception types. 
diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 9d72dedc9..913818dad 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -126,7 +126,7 @@ def _convert_ast_to_py(self, ast_node): } raise ValueError(f"Unsupported type {type(ast_node)}") - def _node_eq(self, node: ast.Node): + def _node_eq(self, node): import ast if "CrossSync" in ast.dump(node): From 9e1afc3481115829969a644e884890e0b9cf8d81 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 11 Jul 2024 15:33:28 -0700 Subject: [PATCH 154/360] fixed mypy error --- google/cloud/bigtable/data/_sync/cross_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 913818dad..fa563955f 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -52,7 +52,7 @@ def pytest_mark_asyncio(func): def pytest_asyncio_fixture(*args, **kwargs): - import pytest_asyncio + import pytest_asyncio # type: ignore def decorator(func): return pytest_asyncio.fixture(*args, **kwargs)(func) From 2a466300175bf87175c72551b082cb19d74fd817 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 11:02:49 -0700 Subject: [PATCH 155/360] added back file --- tests/unit/data/test_read_rows_acceptance.py | 354 +++++++++++++++++++ 1 file changed, 354 insertions(+) create mode 100644 tests/unit/data/test_read_rows_acceptance.py diff --git a/tests/unit/data/test_read_rows_acceptance.py b/tests/unit/data/test_read_rows_acceptance.py new file mode 100644 index 000000000..0b844ec3c --- /dev/null +++ b/tests/unit/data/test_read_rows_acceptance.py @@ -0,0 +1,354 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import os +import warnings +import pytest +import mock +import proto + +from itertools import zip_longest + +from google.cloud.bigtable_v2 import ReadRowsResponse + +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.row import Row + +from ...v2_client.test_row_merger import ReadRowsTest, TestFile + +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +if CrossSync.is_async: + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + from google.cloud.bigtable.data._async.client import BigtableDataClientAsync +else: + from google.cloud.bigtable.data._sync._read_rows import ( # noqa: F401 + _ReadRowsOperation, + ) + from google.cloud.bigtable.data._sync.client import BigtableDataClient # noqa: F401 + + +class TestReadRowsAcceptanceAsync: + @staticmethod + def _get_operation_class(): + return _ReadRowsOperationAsync + + @staticmethod + def _get_client_class(): + return BigtableDataClientAsync + + def parse_readrows_acceptance_tests(): + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "../read-rows-acceptance-test.json") + + with open(filename) as json_file: + test_json = TestFile.from_json(json_file.read()) + return test_json.read_rows_tests + + @staticmethod + def extract_results_from_row(row: Row): + results = [] + for family, col, cells in row.items(): + for cell in cells: + results.append( + ReadRowsTest.Result( + row_key=row.row_key, + family_name=family, + qualifier=col, + timestamp_micros=cell.timestamp_ns // 1000, + value=cell.value, + label=(cell.labels[0] if cell.labels else ""), + ) + ) + return results + + @staticmethod + async def _coro_wrapper(stream): + return stream + + async def _process_chunks(self, *chunks): + async def _row_stream(): + yield ReadRowsResponse(chunks=chunks) + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + results = [] + async for row in merger: + results.append(row) + return results + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + @CrossSync.pytest + async def test_row_merger_scenario(self, test_case: ReadRowsTest): + async def _scenerio_stream(): + for chunk in test_case.chunks: + yield ReadRowsResponse(chunks=[chunk]) + + try: + results = [] + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_scenerio_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + async for row in merger: + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + @CrossSync.pytest + async def test_read_rows_scenario(self, test_case: ReadRowsTest): + async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): 
+ from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list): + self.chunk_list = chunk_list + self.idx = -1 + + def __aiter__(self): + return self + + def __iter__(self): + return self + + async def __anext__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + chunk = self.chunk_list[self.idx] + return ReadRowsResponse(chunks=[chunk]) + raise CrossSync.StopIteration + + def __next__(self): + return self.__anext__() + + def cancel(self): + pass + + return mock_stream(chunk_list) + + with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # use emulator mode to avoid auth issues in CI + client = self._get_client_class()() + try: + table = client.get_table("instance", "table") + results = [] + with mock.patch.object( + table.client._gapic_client, "read_rows" + ) as read_rows: + # run once, then return error on retry + read_rows.return_value = _make_gapic_stream(test_case.chunks) + async for row in await table.read_rows_stream(query={}): + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + finally: + await client.close() + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + @CrossSync.pytest + async def test_out_of_order_rows(self): + async def _row_stream(): + yield ReadRowsResponse(last_scanned_row_key=b"a") + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = b"b" + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + with pytest.raises(InvalidChunk): + async for _ in merger: + pass + + @CrossSync.pytest + async def test_bare_reset(self): + first_chunk = ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk( + row_key=b"a", family_name="f", qualifier=b"q", value=b"v" + ) + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, family_name="f") + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, value=b"v") + ), + ) + + @CrossSync.pytest + async def test_missing_family(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + qualifier=b"q", + 
timestamp_micros=1000, + value=b"v", + commit_row=True, + ) + ) + + @CrossSync.pytest + async def test_mid_cell_row_key_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), + ) + + @CrossSync.pytest + async def test_mid_cell_family_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + family_name="f2", value=b"v", commit_row=True + ), + ) + + @CrossSync.pytest + async def test_mid_cell_qualifier_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + qualifier=b"q2", value=b"v", commit_row=True + ), + ) + + @CrossSync.pytest + async def test_mid_cell_timestamp_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + timestamp_micros=2000, value=b"v", commit_row=True + ), + ) + + @CrossSync.pytest + async def test_mid_cell_labels_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), + ) From 26916862c44358cd4a0f67dfd720674765ef2a68 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 11:04:02 -0700 Subject: [PATCH 156/360] removed file --- tests/unit/data/test_read_rows_acceptance.py | 354 ------------------- 1 file changed, 354 deletions(-) delete mode 100644 tests/unit/data/test_read_rows_acceptance.py diff --git a/tests/unit/data/test_read_rows_acceptance.py b/tests/unit/data/test_read_rows_acceptance.py deleted file mode 100644 index 0b844ec3c..000000000 --- a/tests/unit/data/test_read_rows_acceptance.py +++ /dev/null @@ -1,354 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from __future__ import annotations - -import os -import warnings -import pytest -import mock -import proto - -from itertools import zip_longest - -from google.cloud.bigtable_v2 import ReadRowsResponse - -from google.cloud.bigtable.data.exceptions import InvalidChunk -from google.cloud.bigtable.data.row import Row - -from ...v2_client.test_row_merger import ReadRowsTest, TestFile - -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -if CrossSync.is_async: - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - from google.cloud.bigtable.data._async.client import BigtableDataClientAsync -else: - from google.cloud.bigtable.data._sync._read_rows import ( # noqa: F401 - _ReadRowsOperation, - ) - from google.cloud.bigtable.data._sync.client import BigtableDataClient # noqa: F401 - - -class TestReadRowsAcceptanceAsync: - @staticmethod - def _get_operation_class(): - return _ReadRowsOperationAsync - - @staticmethod - def _get_client_class(): - return BigtableDataClientAsync - - def parse_readrows_acceptance_tests(): - dirname = os.path.dirname(__file__) - filename = os.path.join(dirname, "../read-rows-acceptance-test.json") - - with open(filename) as json_file: - test_json = TestFile.from_json(json_file.read()) - return test_json.read_rows_tests - - @staticmethod - def extract_results_from_row(row: Row): - results = [] - for family, col, cells in row.items(): - for cell in cells: - results.append( - ReadRowsTest.Result( - row_key=row.row_key, - family_name=family, - qualifier=col, - timestamp_micros=cell.timestamp_ns // 1000, - value=cell.value, - label=(cell.labels[0] if cell.labels else ""), - ) - ) - return results - - @staticmethod - async def _coro_wrapper(stream): - return stream - - async def _process_chunks(self, *chunks): - async def _row_stream(): - yield ReadRowsResponse(chunks=chunks) - - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = None - chunker = self._get_operation_class().chunk_stream( - instance, self._coro_wrapper(_row_stream()) - ) - merger = self._get_operation_class().merge_rows(chunker) - results = [] - async for row in merger: - results.append(row) - return results - - @pytest.mark.parametrize( - "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description - ) - @CrossSync.pytest - async def test_row_merger_scenario(self, test_case: ReadRowsTest): - async def _scenerio_stream(): - for chunk in test_case.chunks: - yield ReadRowsResponse(chunks=[chunk]) - - try: - results = [] - instance = mock.Mock() - instance._last_yielded_row_key = None - instance._remaining_count = None - chunker = self._get_operation_class().chunk_stream( - instance, self._coro_wrapper(_scenerio_stream()) - ) - merger = self._get_operation_class().merge_rows(chunker) - async for row in merger: - for cell in row: - cell_result = ReadRowsTest.Result( - row_key=cell.row_key, - family_name=cell.family, - qualifier=cell.qualifier, - timestamp_micros=cell.timestamp_micros, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - @pytest.mark.parametrize( - "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description - ) - @CrossSync.pytest - async def test_read_rows_scenario(self, test_case: ReadRowsTest): - async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): 
- from google.cloud.bigtable_v2 import ReadRowsResponse - - class mock_stream: - def __init__(self, chunk_list): - self.chunk_list = chunk_list - self.idx = -1 - - def __aiter__(self): - return self - - def __iter__(self): - return self - - async def __anext__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - chunk = self.chunk_list[self.idx] - return ReadRowsResponse(chunks=[chunk]) - raise CrossSync.StopIteration - - def __next__(self): - return self.__anext__() - - def cancel(self): - pass - - return mock_stream(chunk_list) - - with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # use emulator mode to avoid auth issues in CI - client = self._get_client_class()() - try: - table = client.get_table("instance", "table") - results = [] - with mock.patch.object( - table.client._gapic_client, "read_rows" - ) as read_rows: - # run once, then return error on retry - read_rows.return_value = _make_gapic_stream(test_case.chunks) - async for row in await table.read_rows_stream(query={}): - for cell in row: - cell_result = ReadRowsTest.Result( - row_key=cell.row_key, - family_name=cell.family, - qualifier=cell.qualifier, - timestamp_micros=cell.timestamp_micros, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - finally: - await client.close() - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - @CrossSync.pytest - async def test_out_of_order_rows(self): - async def _row_stream(): - yield ReadRowsResponse(last_scanned_row_key=b"a") - - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = b"b" - chunker = self._get_operation_class().chunk_stream( - instance, self._coro_wrapper(_row_stream()) - ) - merger = self._get_operation_class().merge_rows(chunker) - with pytest.raises(InvalidChunk): - async for _ in merger: - pass - - @CrossSync.pytest - async def test_bare_reset(self): - first_chunk = ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk( - row_key=b"a", family_name="f", qualifier=b"q", value=b"v" - ) - ) - with pytest.raises(InvalidChunk): - await self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") - ), - ) - with pytest.raises(InvalidChunk): - await self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, family_name="f") - ), - ) - with pytest.raises(InvalidChunk): - await self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") - ), - ) - with pytest.raises(InvalidChunk): - await self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) - ), - ) - with pytest.raises(InvalidChunk): - await self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) - ), - ) - with pytest.raises(InvalidChunk): - await self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, value=b"v") - ), - ) - - @CrossSync.pytest - async def test_missing_family(self): - with pytest.raises(InvalidChunk): - await self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - qualifier=b"q", - 
timestamp_micros=1000, - value=b"v", - commit_row=True, - ) - ) - - @CrossSync.pytest - async def test_mid_cell_row_key_change(self): - with pytest.raises(InvalidChunk): - await self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), - ) - - @CrossSync.pytest - async def test_mid_cell_family_change(self): - with pytest.raises(InvalidChunk): - await self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - family_name="f2", value=b"v", commit_row=True - ), - ) - - @CrossSync.pytest - async def test_mid_cell_qualifier_change(self): - with pytest.raises(InvalidChunk): - await self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - qualifier=b"q2", value=b"v", commit_row=True - ), - ) - - @CrossSync.pytest - async def test_mid_cell_timestamp_change(self): - with pytest.raises(InvalidChunk): - await self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - timestamp_micros=2000, value=b"v", commit_row=True - ), - ) - - @CrossSync.pytest - async def test_mid_cell_labels_change(self): - with pytest.raises(InvalidChunk): - await self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), - ) From f48604e8ca92a85dd6add8ff4a4421dff4a09c89 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 11:17:09 -0700 Subject: [PATCH 157/360] moved non-async helper classes out of async folder --- google/cloud/bigtable/data/_async/_mutate_rows.py | 15 +-------------- google/cloud/bigtable/data/_async/_read_rows.py | 6 +----- google/cloud/bigtable/data/_helpers.py | 2 +- google/cloud/bigtable/data/exceptions.py | 15 +++++++++++++++ google/cloud/bigtable/data/mutations.py | 12 ++++++++++++ tests/system/data/__init__.py | 3 +++ tests/system/data/test_system_async.py | 5 ++--- 7 files changed, 35 insertions(+), 23 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 3de8a849e..7c40b492c 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -15,12 +15,10 @@ from __future__ import annotations from typing import Sequence, TYPE_CHECKING -from dataclasses import dataclass import functools from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries -import google.cloud.bigtable_v2.types.bigtable as types_pb import google.cloud.bigtable.data.exceptions as bt_exceptions from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _attempt_timeout_generator @@ -28,6 +26,7 @@ # mutate_rows requests are limited to this number of mutations from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable.data.mutations import _EntryWithProto from 
google.cloud.bigtable.data._sync.cross_sync import CrossSync @@ -41,16 +40,6 @@ ) -@dataclass -class _EntryWithProto: # noqa: F811 - """ - A dataclass to hold a RowMutationEntry and its corresponding proto representation. - """ - - entry: RowMutationEntry - proto: types_pb.MutateRowsRequest.Entry - - class _MutateRowsOperationAsync: """ MutateRowsOperation manages the logic of sending a set of row mutations, @@ -104,8 +93,6 @@ def __init__( bt_exceptions._MutateRowsIncomplete, ) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - # Note: _operation could be a raw coroutine, but using a lambda - # wrapper helps unify with sync code self._operation = lambda: CrossSync.retry_target( self._run_attempt, self.is_retryable, diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 0a4ebbf55..dfc9c1adb 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -29,6 +29,7 @@ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _RowSetComplete +from google.cloud.bigtable.data.exceptions import _ResetRow from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory @@ -44,11 +45,6 @@ from google.cloud.bigtable.data._async.client import TableAsync -class _ResetRow(Exception): # noqa: F811 - def __init__(self, chunk): - self.chunk = chunk - - class _ReadRowsOperationAsync: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index 7e7038734..a8113cc4a 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -137,7 +137,7 @@ def _retry_exception_factory( def _get_timeouts( operation: float | TABLE_DEFAULT, attempt: float | None | TABLE_DEFAULT, - table: "TableAsync" + table: "TableAsync", ) -> tuple[float, float]: """ Convert passed in timeout values to floats, using table defaults if necessary. diff --git a/google/cloud/bigtable/data/exceptions.py b/google/cloud/bigtable/data/exceptions.py index 8d97640aa..8065ed9d1 100644 --- a/google/cloud/bigtable/data/exceptions.py +++ b/google/cloud/bigtable/data/exceptions.py @@ -41,6 +41,21 @@ class _RowSetComplete(Exception): pass +class _ResetRow(Exception): # noqa: F811 + """ + Internal exception for _ReadRowsOperation + + Denotes that the server sent a reset_row marker, telling the client to drop + all previous chunks for row_key and re-read from the beginning. + + Args: + chunk: the reset_row chunk + """ + + def __init__(self, chunk): + self.chunk = chunk + + class _MutateRowsIncomplete(RuntimeError): """ Exception raised when a mutate_rows call has unfinished work. diff --git a/google/cloud/bigtable/data/mutations.py b/google/cloud/bigtable/data/mutations.py index 335a15e12..2f4e441ed 100644 --- a/google/cloud/bigtable/data/mutations.py +++ b/google/cloud/bigtable/data/mutations.py @@ -366,3 +366,15 @@ def _from_dict(cls, input_dict: dict[str, Any]) -> RowMutationEntry: Mutation._from_dict(mutation) for mutation in input_dict["mutations"] ], ) + + +@dataclass +class _EntryWithProto: + """ + A dataclass to hold a RowMutationEntry and its corresponding proto representation. 
+ + Used in _MutateRowsOperation to avoid repeated conversion of RowMutationEntry to proto. + """ + + entry: RowMutationEntry + proto: types_pb.MutateRowsRequest.Entry diff --git a/tests/system/data/__init__.py b/tests/system/data/__init__.py index 89a37dc92..f2952b2cd 100644 --- a/tests/system/data/__init__.py +++ b/tests/system/data/__init__.py @@ -13,3 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +TEST_FAMILY = "test-family" +TEST_FAMILY_2 = "test-family-2" diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 4d93fce78..95c14cf12 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -25,15 +25,14 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from . import TEST_FAMILY, TEST_FAMILY_2 + if CrossSync.is_async: from google.cloud.bigtable.data._async.client import BigtableDataClientAsync else: from google.cloud.bigtable.data._sync.client import BigtableDataClient from .test_system_async import TEST_FAMILY, TEST_FAMILY_2 -TEST_FAMILY = "test-family" -TEST_FAMILY_2 = "test-family-2" - class TempRowBuilderAsync: """ From a61c54fc854ae2dafff6d5e14d40d0d0ee8835fc Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 11:27:02 -0700 Subject: [PATCH 158/360] fixed lint --- tests/system/data/test_system_async.py | 4 ---- tests/unit/data/_async/test_mutations_batcher.py | 8 -------- tests/unit/data/_async/test_read_rows_acceptance.py | 1 - 3 files changed, 13 deletions(-) diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 95c14cf12..cbd8d7605 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -13,7 +13,6 @@ # limitations under the License. 
import pytest -import pytest_asyncio import asyncio import uuid import os @@ -29,9 +28,6 @@ if CrossSync.is_async: from google.cloud.bigtable.data._async.client import BigtableDataClientAsync -else: - from google.cloud.bigtable.data._sync.client import BigtableDataClient - from .test_system_async import TEST_FAMILY, TEST_FAMILY_2 class TempRowBuilderAsync: diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 6386adc7f..a26dcfd64 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -272,14 +272,6 @@ async def test_add_to_flow_max_mutation_limits( Test flow control running up against the max API limit Should submit request early, even if the flow control has room for more """ - async_patch = mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", - max_limit, - ) - sync_patch = mock.patch( - "google.cloud.bigtable.data._sync.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", - max_limit, - ) subpath = "_async" if CrossSync.is_async else "_sync" path = f"google.cloud.bigtable.data.{subpath}.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT" with mock.patch(path, max_limit): diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 0b844ec3c..15b181637 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -17,7 +17,6 @@ import warnings import pytest import mock -import proto from itertools import zip_longest From 8b379c881505c4057cb9bc4f142782d94484da90 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 11:33:24 -0700 Subject: [PATCH 159/360] improve version string calculation --- google/cloud/bigtable/data/_async/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 14a505ce7..0dfa7fdc8 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -210,10 +210,10 @@ def _client_version() -> str: """ Helper function to return the client version string for this client """ + version_str = f"{google.cloud.bigtable.__version__}-data" if CrossSync.is_async: - return f"{google.cloud.bigtable.__version__}-data-async" - else: - return f"{google.cloud.bigtable.__version__}-data" + version_str += "-async" + return version_str def _start_background_channel_refresh(self) -> None: """ From 14259e23ba16442857c15e3597eba0c3282ca881 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 11:38:39 -0700 Subject: [PATCH 160/360] created method for event loop verification --- google/cloud/bigtable/data/_async/client.py | 5 ++--- google/cloud/bigtable/data/_sync/cross_sync.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 0dfa7fdc8..a998fd750 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -227,9 +227,8 @@ def _start_background_channel_refresh(self) -> None: and not self._emulator_host and not self._is_closed.is_set() ): - if CrossSync.is_async: - # raise error if not in an event loop - asyncio.get_running_loop() + # raise error if not in an event loop in async client + CrossSync.verify_async_event_loop() 
for channel_idx in range(self.transport.pool_size): refresh_task = CrossSync.create_task( self._manage_channel, diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index fa563955f..d4e7e985a 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -301,6 +301,13 @@ async def yield_to_event_loop() -> None: """ await asyncio.sleep(0) + @staticmethod + def verify_async_event_loop() -> None: + """ + Raises RuntimeError if the event loop is not running + """ + asyncio.get_running_loop() + class _Sync_Impl: is_async = False @@ -406,3 +413,7 @@ def create_task( @staticmethod def yield_to_event_loop() -> None: pass + + @staticmethod + def verify_async_event_loop() -> None: + pass From 9b7c1e2f0f77b9dac96fcaded2f7cf0b11594081 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 11:54:28 -0700 Subject: [PATCH 161/360] reverted some behavior --- google/cloud/bigtable/data/_async/client.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index a998fd750..d1d349bca 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -237,13 +237,8 @@ def _start_background_channel_refresh(self) -> None: task_name=f"{self.__class__.__name__} channel refresh {channel_idx}", ) self._channel_refresh_tasks.append(refresh_task) - refresh_task.add_done_callback( - lambda _: self._channel_refresh_tasks.remove(refresh_task) - if refresh_task in self._channel_refresh_tasks - else None - ) - async def close(self, timeout: float | None = None): + async def close(self, timeout: float | None = 2.0): """ Cancel all background tasks """ @@ -254,6 +249,7 @@ async def close(self, timeout: float | None = None): if self._executor: self._executor.shutdown(wait=False) await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) + self._channel_refresh_tasks = [] async def _ping_and_warm_instances( self, channel: Channel, instance_key: _WarmedInstanceKey | None = None From 88fda0d58bfa8f5b0dc5217f7ecc559a7b408310 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 11:54:40 -0700 Subject: [PATCH 162/360] added comments --- google/cloud/bigtable/data/_async/client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index d1d349bca..a2ac7f2cc 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -329,9 +329,12 @@ async def _manage_channel( # continuously refresh the channel every `refresh_interval` seconds while not self._is_closed.is_set(): await CrossSync.event_wait( - self._is_closed, next_sleep, async_break_early=False + self._is_closed, + next_sleep, + async_break_early=False, # no need to interrupt sleep. 
Task will be cancelled on close ) if self._is_closed.is_set(): + # don't refresh if client is closed break # prepare new channel for use new_channel = self.transport.grpc_channel._create_channel() From c3787cabea114dda24f527966bec7aec05a8f3e4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 12:16:26 -0700 Subject: [PATCH 163/360] added comments --- .../cloud/bigtable/data/_sync/cross_sync.py | 79 ++++++++++++++++--- 1 file changed, 68 insertions(+), 11 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index d4e7e985a..82e708b70 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -43,6 +43,12 @@ def pytest_mark_asyncio(func): + """ + Applies pytest.mark.asyncio to a function if pytest is installed, otherwise + returns the function as is + + Used to support CrossSync.pytest decorator, without requiring pytest to be installed + """ try: import pytest @@ -52,6 +58,12 @@ def pytest_mark_asyncio(func): def pytest_asyncio_fixture(*args, **kwargs): + """ + Applies pytest.fixture to a function if pytest is installed, otherwise + returns the function as is + + Used to support CrossSync.pytest_fixture decorator, without requiring pytest to be installed + """ import pytest_asyncio # type: ignore def decorator(func): @@ -66,36 +78,63 @@ class AstDecorator: These decorators provide arguments that are used during the code generation process, but act as no-ops when encountered in live code + + Args: + attr_name: name of the attribute to attach to the CrossSync class + e.g. pytest for CrossSync.pytest + required_keywords: list of required keyword arguments for the decorator. + If the decorator is used without these arguments, a ValueError is + raised during code generation + async_impl: If given, the async code will apply this decorator to its + wrapped function at runtime. If not given, the decorator will be a no-op + **default_kwargs: any kwargs passed define the valid arguments when using the decorator. + The value of each kwarg is the default value for the argument. """ def __init__( self, - decorator_name, + attr_name, required_keywords=(), - inner_decorator=None, + async_impl=None, **default_kwargs, ): - self.name = decorator_name + self.name = attr_name self.required_kwargs = required_keywords self.default_kwargs = default_kwargs self.all_valid_keys = [*required_keywords, *default_kwargs.keys()] - self.inner_decorator = inner_decorator + self.async_impl = async_impl def __call__(self, *args, **kwargs): + """ + Called when the decorator is used in code. 
+ + Returns a no-op decorator function, or applies the async_impl decorator + """ + # raise error if invalid kwargs are passed for kwarg in kwargs: if kwarg not in self.all_valid_keys: raise ValueError(f"Invalid keyword argument: {kwarg}") - if self.inner_decorator: - return self.inner_decorator(*args, **kwargs) + # if async_impl is provided, use the given decorator function + if self.async_impl: + return self.async_impl(**{**self.default_kwargs, **kwargs}) + # if no arguments, args[0] will hold the function to be decorated + # return the function as is if len(args) == 1 and callable(args[0]): return args[0] + # if arguments are provided, return a no-op decorator function def decorator(func): return func return decorator def parse_ast_keywords(self, node): + """ + When this decorator is encountered in the ast during sync generation, parse the + keyword arguments back from ast nodes to python primitives + + Return a full set of kwargs, using default values for missing arguments + """ got_kwargs = ( {kw.arg: self._convert_ast_to_py(kw.value) for kw in node.keywords} if hasattr(node, "keywords") @@ -127,6 +166,9 @@ def _convert_ast_to_py(self, ast_node): raise ValueError(f"Unsupported type {type(ast_node)}") def _node_eq(self, node): + """ + Check if the given ast node is a call to this decorator + """ import ast if "CrossSync" in ast.dump(node): @@ -136,6 +178,9 @@ def _node_eq(self, node): return False def __eq__(self, other): + """ + Helper to support == comparison with ast nodes + """ return self._node_eq(other) @@ -153,8 +198,10 @@ def __getattr__(self, name): class CrossSync(metaclass=_DecoratorMeta): + # support CrossSync.is_async to check if the current environment is async is_async = True + # provide aliases for common async functions and types sleep = asyncio.sleep retry_target = retries.retry_target_async retry_target_stream = retries.retry_target_stream_async @@ -166,7 +213,7 @@ class CrossSync(metaclass=_DecoratorMeta): Event: TypeAlias = asyncio.Event Semaphore: TypeAlias = asyncio.Semaphore StopIteration: TypeAlias = StopAsyncIteration - # type annotations + # provide aliases for common async type annotations Awaitable: TypeAlias = typing.Awaitable Iterable: TypeAlias = AsyncIterable Iterator: TypeAlias = AsyncIterator @@ -174,11 +221,11 @@ class CrossSync(metaclass=_DecoratorMeta): _decorators: list[AstDecorator] = [ AstDecorator( - "pytest", inner_decorator=pytest_mark_asyncio + "pytest", async_impl=pytest_mark_asyncio ), # decorate test methods to run with pytest-asyncio AstDecorator( "pytest_fixture", # decorate test methods to run with pytest fixture - inner_decorator=pytest_asyncio_fixture, + async_impl=pytest_asyncio_fixture, scope="function", params=None, autouse=False, @@ -189,6 +236,9 @@ class CrossSync(metaclass=_DecoratorMeta): @classmethod def Mock(cls, *args, **kwargs): + """ + Alias for AsyncMock, importing at runtime to avoid hard dependency on mock + """ try: from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER @@ -309,6 +359,9 @@ def verify_async_event_loop() -> None: asyncio.get_running_loop() class _Sync_Impl: + """ + Provide sync versions of the async functions and types in CrossSync + """ is_async = False sleep = time.sleep @@ -328,8 +381,6 @@ class _Sync_Impl: Iterator: TypeAlias = typing.Iterator Generator: TypeAlias = typing.Generator - generated_replacements: dict[type, str] = {} - @classmethod def Mock(cls, *args, **kwargs): # try/except added for compatibility with python < 3.8 @@ -412,8 +463,14 @@ def 
create_task( @staticmethod def yield_to_event_loop() -> None: + """ + No-op for sync version + """ pass @staticmethod def verify_async_event_loop() -> None: + """ + No-op for sync version + """ pass From 7f650634bc96338a5a10c62fbceba965f2d35668 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 13:06:49 -0700 Subject: [PATCH 164/360] removed sync implementation from cross_sync --- .../cloud/bigtable/data/_sync/cross_sync.py | 117 ------------------ 1 file changed, 117 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 82e708b70..b4fc4929d 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -357,120 +357,3 @@ def verify_async_event_loop() -> None: Raises RuntimeError if the event loop is not running """ asyncio.get_running_loop() - - class _Sync_Impl: - """ - Provide sync versions of the async functions and types in CrossSync - """ - is_async = False - - sleep = time.sleep - retry_target = retries.retry_target - retry_target_stream = retries.retry_target_stream - Retry = retries.Retry - Queue: TypeAlias = queue.Queue - Condition: TypeAlias = threading.Condition - Future: TypeAlias = concurrent.futures.Future - Task: TypeAlias = concurrent.futures.Future - Event: TypeAlias = threading.Event - Semaphore: TypeAlias = threading.Semaphore - StopIteration: TypeAlias = StopIteration - # type annotations - Awaitable: TypeAlias = Union[T] - Iterable: TypeAlias = typing.Iterable - Iterator: TypeAlias = typing.Iterator - Generator: TypeAlias = typing.Generator - - @classmethod - def Mock(cls, *args, **kwargs): - # try/except added for compatibility with python < 3.8 - try: - from unittest.mock import Mock - except ImportError: # pragma: NO COVER - from mock import Mock # type: ignore - return Mock(*args, **kwargs) - - @staticmethod - def wait( - futures: Sequence[CrossSync._Sync_Impl.Future[T]], - timeout: float | None = None, - ) -> tuple[ - set[CrossSync._Sync_Impl.Future[T]], set[CrossSync._Sync_Impl.Future[T]] - ]: - """ - abstraction over asyncio.wait - """ - if not futures: - return set(), set() - return concurrent.futures.wait(futures, timeout=timeout) - - @staticmethod - def condition_wait( - condition: CrossSync._Sync_Impl.Condition, timeout: float | None = None - ) -> bool: - """ - returns False if the timeout is reached before the condition is set, otherwise True - """ - return condition.wait(timeout=timeout) - - @staticmethod - def event_wait( - event: CrossSync._Sync_Impl.Event, - timeout: float | None = None, - async_break_early: bool = True, - ) -> None: - event.wait(timeout=timeout) - - @staticmethod - def gather_partials( - partial_list: Sequence[Callable[[], T]], - return_exceptions: bool = False, - sync_executor: concurrent.futures.ThreadPoolExecutor | None = None, - ) -> list[T | BaseException]: - if not partial_list: - return [] - if not sync_executor: - raise ValueError("sync_executor is required for sync version") - futures_list = [sync_executor.submit(partial) for partial in partial_list] - results_list: list[T | BaseException] = [] - for future in futures_list: - found_exc = future.exception() - if found_exc is not None: - if return_exceptions: - results_list.append(found_exc) - else: - raise found_exc - else: - results_list.append(future.result()) - return results_list - - @staticmethod - def create_task( - fn: Callable[..., T], - *fn_args, - sync_executor: concurrent.futures.ThreadPoolExecutor | None = None, - 
task_name: str | None = None, - **fn_kwargs, - ) -> CrossSync._Sync_Impl.Task[T]: - """ - abstraction over asyncio.create_task. Sync version implemented with threadpool executor - - sync_executor: ThreadPoolExecutor to use for sync operations. Ignored in async version - """ - if not sync_executor: - raise ValueError("sync_executor is required for sync version") - return sync_executor.submit(fn, *fn_args, **fn_kwargs) - - @staticmethod - def yield_to_event_loop() -> None: - """ - No-op for sync version - """ - pass - - @staticmethod - def verify_async_event_loop() -> None: - """ - No-op for sync version - """ - pass From caf27e23d144826e32f74b5e6fc58a5b97f90154 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 13:13:31 -0700 Subject: [PATCH 165/360] added sync_impl --- .../cloud/bigtable/data/_sync/cross_sync.py | 117 ++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index b4fc4929d..82e708b70 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -357,3 +357,120 @@ def verify_async_event_loop() -> None: Raises RuntimeError if the event loop is not running """ asyncio.get_running_loop() + + class _Sync_Impl: + """ + Provide sync versions of the async functions and types in CrossSync + """ + is_async = False + + sleep = time.sleep + retry_target = retries.retry_target + retry_target_stream = retries.retry_target_stream + Retry = retries.Retry + Queue: TypeAlias = queue.Queue + Condition: TypeAlias = threading.Condition + Future: TypeAlias = concurrent.futures.Future + Task: TypeAlias = concurrent.futures.Future + Event: TypeAlias = threading.Event + Semaphore: TypeAlias = threading.Semaphore + StopIteration: TypeAlias = StopIteration + # type annotations + Awaitable: TypeAlias = Union[T] + Iterable: TypeAlias = typing.Iterable + Iterator: TypeAlias = typing.Iterator + Generator: TypeAlias = typing.Generator + + @classmethod + def Mock(cls, *args, **kwargs): + # try/except added for compatibility with python < 3.8 + try: + from unittest.mock import Mock + except ImportError: # pragma: NO COVER + from mock import Mock # type: ignore + return Mock(*args, **kwargs) + + @staticmethod + def wait( + futures: Sequence[CrossSync._Sync_Impl.Future[T]], + timeout: float | None = None, + ) -> tuple[ + set[CrossSync._Sync_Impl.Future[T]], set[CrossSync._Sync_Impl.Future[T]] + ]: + """ + abstraction over asyncio.wait + """ + if not futures: + return set(), set() + return concurrent.futures.wait(futures, timeout=timeout) + + @staticmethod + def condition_wait( + condition: CrossSync._Sync_Impl.Condition, timeout: float | None = None + ) -> bool: + """ + returns False if the timeout is reached before the condition is set, otherwise True + """ + return condition.wait(timeout=timeout) + + @staticmethod + def event_wait( + event: CrossSync._Sync_Impl.Event, + timeout: float | None = None, + async_break_early: bool = True, + ) -> None: + event.wait(timeout=timeout) + + @staticmethod + def gather_partials( + partial_list: Sequence[Callable[[], T]], + return_exceptions: bool = False, + sync_executor: concurrent.futures.ThreadPoolExecutor | None = None, + ) -> list[T | BaseException]: + if not partial_list: + return [] + if not sync_executor: + raise ValueError("sync_executor is required for sync version") + futures_list = [sync_executor.submit(partial) for partial in partial_list] + results_list: list[T | BaseException] = [] 
+ for future in futures_list: + found_exc = future.exception() + if found_exc is not None: + if return_exceptions: + results_list.append(found_exc) + else: + raise found_exc + else: + results_list.append(future.result()) + return results_list + + @staticmethod + def create_task( + fn: Callable[..., T], + *fn_args, + sync_executor: concurrent.futures.ThreadPoolExecutor | None = None, + task_name: str | None = None, + **fn_kwargs, + ) -> CrossSync._Sync_Impl.Task[T]: + """ + abstraction over asyncio.create_task. Sync version implemented with threadpool executor + + sync_executor: ThreadPoolExecutor to use for sync operations. Ignored in async version + """ + if not sync_executor: + raise ValueError("sync_executor is required for sync version") + return sync_executor.submit(fn, *fn_args, **fn_kwargs) + + @staticmethod + def yield_to_event_loop() -> None: + """ + No-op for sync version + """ + pass + + @staticmethod + def verify_async_event_loop() -> None: + """ + No-op for sync version + """ + pass From 8af643842cb19aff08e597670f81452748f25944 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 13:38:26 -0700 Subject: [PATCH 166/360] added cross sync ast transform system --- .cross_sync/generate.py | 92 ++++++++++++ .cross_sync/transformers.py | 276 ++++++++++++++++++++++++++++++++++++ 2 files changed, 368 insertions(+) create mode 100644 .cross_sync/generate.py create mode 100644 .cross_sync/transformers.py diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py new file mode 100644 index 000000000..d3a234ff7 --- /dev/null +++ b/.cross_sync/generate.py @@ -0,0 +1,92 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations +import ast +from dataclasses import dataclass, field + + +@dataclass +class CrossSyncFileArtifact: + """ + Used to track an output file location. 
Collects a number of converted classes, and then + writes them to disk + """ + + file_path: str + imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = field( + default_factory=list + ) + converted_classes: list[ast.ClassDef] = field(default_factory=list) + contained_classes: set[str] = field(default_factory=set) + mypy_ignore: list[str] = field(default_factory=list) + + def __hash__(self): + return hash(self.file_path) + + def __repr__(self): + return f"CrossSyncFileArtifact({self.file_path}, classes={[c.name for c in self.converted_classes]})" + + def render(self, with_black=True, save_to_disk=False) -> str: + full_str = ( + "# Copyright 2024 Google LLC\n" + "#\n" + '# Licensed under the Apache License, Version 2.0 (the "License");\n' + "# you may not use this file except in compliance with the License.\n" + "# You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing, software\n" + '# distributed under the License is distributed on an "AS IS" BASIS,\n' + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + "# See the License for the specific language governing permissions and\n" + "# limitations under the License.\n" + "#\n" + "# This file is automatically generated by CrossSync. Do not edit manually.\n" + ) + if self.mypy_ignore: + full_str += ( + f'\n# mypy: disable-error-code="{",".join(self.mypy_ignore)}"\n\n' + ) + full_str += "\n".join([ast.unparse(node) for node in self.imports]) # type: ignore + full_str += "\n\n" + full_str += "\n".join([ast.unparse(node) for node in self.converted_classes]) # type: ignore + if with_black: + import black # type: ignore + import autoflake # type: ignore + + full_str = black.format_str( + autoflake.fix_code(full_str, remove_all_unused_imports=True), + mode=black.FileMode(), + ) + if save_to_disk: + with open(self.file_path, "w") as f: + f.write(full_str) + return full_str + +if __name__ == "__main__": + import glob + import sys + from transformers import CrossSyncClassDecoratorHandler + + # find all cross_sync decorated classes + search_root = sys.argv[1] + files = glob.glob(search_root + "/**/*.py", recursive=True) + artifacts: set[CrossSyncFileArtifact] = set() + for file in files: + converter = CrossSyncClassDecoratorHandler(file) + converter.convert_file(artifacts) + print(artifacts) + for artifact in artifacts: + artifact.render(save_to_disk=True) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py new file mode 100644 index 000000000..fff626ca3 --- /dev/null +++ b/.cross_sync/transformers.py @@ -0,0 +1,276 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
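# Illustrative aside: the passes defined below can be driven directly with the ast
# module. A minimal sketch, assuming SymbolReplacer has been imported and using a
# made-up snippet of async source:
#
#   import ast
#   src = "async def foo():\n    await CrossSync.sleep(1)\n"
#   tree = ast.parse(src)
#   tree = SymbolReplacer({"CrossSync": "CrossSync._Sync_Impl"}).visit(tree)
#   print(ast.unparse(tree))
#   # async def foo():
#   #     await CrossSync._Sync_Impl.sleep(1)
#
# AsyncToSync is applied in the same visit/unparse fashion to strip the async and
# await keywords when a full class is converted.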
+from __future__ import annotations + +import ast + +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + + +class SymbolReplacer(ast.NodeTransformer): + """ + Replaces all instances of a symbol in an AST with a replacement + + Works for function signatures, method calls, docstrings, and type annotations + """ + def __init__(self, replacements: dict[str, str]): + self.replacements = replacements + + def visit_Name(self, node): + if node.id in self.replacements: + node.id = self.replacements[node.id] + return node + + def visit_Attribute(self, node): + return ast.copy_location( + ast.Attribute( + self.visit(node.value), + self.replacements.get(node.attr, node.attr), + node.ctx, + ), + node, + ) + + def visit_AsyncFunctionDef(self, node): + """ + Replace async function docstrings + """ + # use same logic as FunctionDef + return self.visit_FunctionDef(node) + + def visit_FunctionDef(self, node): + """ + Replace function docstrings + """ + docstring = ast.get_docstring(node) + if docstring and isinstance(node.body[0], ast.Expr) and isinstance( + node.body[0].value, ast.Str + ): + for key_word, replacement in self.replacements.items(): + docstring = docstring.replace(f" {key_word} ", f" {replacement} ") + node.body[0].value.s = docstring + return self.generic_visit(node) + + def visit_Str(self, node): + """Replace string type annotations""" + node.s = self.replacements.get(node.s, node.s) + return node + + +class AsyncToSync(ast.NodeTransformer): + """ + Replaces or strips all async keywords from a given AST + """ + def visit_Await(self, node): + """ + Strips await keyword + """ + return self.visit(node.value) + + def visit_AsyncFor(self, node): + """ + Replaces `async for` with `for` + """ + return ast.copy_location( + ast.For( + self.visit(node.target), + self.visit(node.iter), + [self.visit(stmt) for stmt in node.body], + [self.visit(stmt) for stmt in node.orelse], + ), + node, + ) + + def visit_AsyncWith(self, node): + """ + Replaces `async with` with `with` + """ + return ast.copy_location( + ast.With( + [self.visit(item) for item in node.items], + [self.visit(stmt) for stmt in node.body], + ), + node, + ) + + def visit_AsyncFunctionDef(self, node): + """ + Replaces `async def` with `def` + """ + return ast.copy_location( + ast.FunctionDef( + node.name, + self.visit(node.args), + [self.visit(stmt) for stmt in node.body], + [self.visit(decorator) for decorator in node.decorator_list], + node.returns and self.visit(node.returns), + ), + node, + ) + + def visit_ListComp(self, node): + """ + Replaces `async for` with `for` in list comprehensions + """ + for generator in node.generators: + generator.is_async = False + return self.generic_visit(node) + + +class CrossSyncMethodDecoratorHandler(ast.NodeTransformer): + """ + Visits each method in a class, and handles any CrossSync decorators found + """ + + def visit_FunctionDef(self, node): + return self.visit_AsyncFunctionDef(node) + + def visit_AsyncFunctionDef(self, node): + try: + if hasattr(node, "decorator_list"): + found_list, node.decorator_list = node.decorator_list, [] + for decorator in found_list: + if decorator == CrossSync.convert: + # convert async to sync + kwargs = CrossSync.convert.parse_ast_keywords(decorator) + node = AsyncToSync().visit(node) + # replace method name if specified + if kwargs["sync_name"] is not None: + node.name = kwargs["sync_name"] + # replace symbols if specified + if kwargs["replace_symbols"]: + node = SymbolReplacer(kwargs["replace_symbols"]).visit(node) + elif decorator == 
CrossSync.drop_method: + # drop method entirely from class + return None + elif decorator == CrossSync.pytest: + # also convert pytest methods to sync + node = AsyncToSync().visit(node) + elif decorator == CrossSync.pytest_fixture: + # add pytest.fixture decorator + decorator.func.value = ast.Name(id="pytest", ctx=ast.Load()) + decorator.func.attr = "fixture" + node.decorator_list.append(decorator) + else: + # keep unknown decorators + node.decorator_list.append(decorator) + return node + except ValueError as e: + raise ValueError(f"node {node.name} failed") from e + + +class CrossSyncClassDecoratorHandler(ast.NodeTransformer): + """ + Visits each class in the file, and if it has a CrossSync decorator, it will be transformed. + + Uses CrossSyncMethodDecoratorHandler to visit and (potentially) convert each method in the class + """ + def __init__(self, file_path): + self.in_path = file_path + self._artifact_dict: dict[str, CrossSyncFileArtifact] = {} + self.imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = [] + self.cross_sync_symbol_transformer = SymbolReplacer( + {"CrossSync": "CrossSync._Sync_Impl"} + ) + self.cross_sync_method_handler = CrossSyncMethodDecoratorHandler() + + def convert_file( + self, artifacts: set[CrossSyncFileArtifact] | None = None + ) -> set[CrossSyncFileArtifact]: + """ + Called to run a file through the transformer. If any classes are marked with a CrossSync decorator, + they will be transformed and added to an artifact for the output file + """ + tree = ast.parse(open(self.in_path).read()) + self._artifact_dict = {f.file_path: f for f in artifacts or []} + self.imports = self._get_imports(tree) + self.visit(tree) + found = set(self._artifact_dict.values()) + if artifacts is not None: + artifacts.update(found) + return found + + def visit_ClassDef(self, node): + """ + Called for each class in file. 
If class has a CrossSync decorator, it will be transformed + according to the decorator arguments + """ + try: + for decorator in node.decorator_list: + if decorator == CrossSync.export_sync: + kwargs = CrossSync.export_sync.parse_ast_keywords(decorator) + # find the path to write the sync class to + sync_path = kwargs["path"] + out_file = "/".join(sync_path.rsplit(".")[:-1]) + ".py" + sync_cls_name = sync_path.rsplit(".", 1)[-1] + # find the artifact file for the save location + output_artifact = self._artifact_dict.get( + out_file, CrossSyncFileArtifact(out_file) + ) + # write converted class details if not already present + if sync_cls_name not in output_artifact.contained_classes: + converted = self._transform_class(node, sync_cls_name, **kwargs) + output_artifact.converted_classes.append(converted) + # handle file-level mypy ignores + mypy_ignores = [ + s + for s in kwargs["mypy_ignore"] + if s not in output_artifact.mypy_ignore + ] + output_artifact.mypy_ignore.extend(mypy_ignores) + # handle file-level imports + if not output_artifact.imports and kwargs["include_file_imports"]: + output_artifact.imports = self.imports + self._artifact_dict[out_file] = output_artifact + return node + except ValueError as e: + raise ValueError(f"failed for class: {node.name}") from e + + def _transform_class( + self, + cls_ast: ast.ClassDef, + new_name: str, + replace_symbols: dict[str, str] | None = None, + **kwargs, + ) -> ast.ClassDef: + """ + Transform async class into sync one, by running through a series of transformers + """ + # update name + cls_ast.name = new_name + # strip CrossSync decorators + if hasattr(cls_ast, "decorator_list"): + cls_ast.decorator_list = [ + d for d in cls_ast.decorator_list if "CrossSync" not in ast.dump(d) + ] + # convert class contents + cls_ast = self.cross_sync_symbol_transformer.visit(cls_ast) + if replace_symbols: + cls_ast = SymbolReplacer(replace_symbols).visit(cls_ast) + cls_ast = self.cross_sync_method_handler.visit(cls_ast) + return cls_ast + + def _get_imports( + self, tree: ast.Module + ) -> list[ast.Import | ast.ImportFrom | ast.Try | ast.If]: + """ + Grab the imports from the top of the file + """ + imports = [] + for node in tree.body: + if isinstance(node, (ast.Import, ast.ImportFrom, ast.Try, ast.If)): + imports.append(self.cross_sync_symbol_transformer.visit(node)) + return imports + + From 8d763f83008c5c8dd4f444717792f40f545486b2 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 13:59:12 -0700 Subject: [PATCH 167/360] added conversion annotations This reverts commit 245bd0894a61b3f309d79cdfa040dd344eabc29f. 
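For illustration, a class annotated with these decorators might look like this
(hypothetical names, not taken from the library):

    @CrossSync.export_sync(path="my_pkg._sync.widget.Widget")
    class WidgetAsync:
        @CrossSync.convert(sync_name="wait_ready", replace_symbols={"HelperAsync": "Helper"})
        async def wait_ready_async(self):
            helper = HelperAsync()
            await CrossSync.sleep(1)

At runtime the decorators are no-ops, so the async class behaves as written. When
the .cross_sync generator from the previous commit is run, it would emit roughly:

    class Widget:
        def wait_ready(self):
            helper = Helper()
            CrossSync._Sync_Impl.sleep(1)

into my_pkg/_sync/widget.py.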
--- .../bigtable/data/_async/_mutate_rows.py | 11 ++++ .../cloud/bigtable/data/_async/_read_rows.py | 8 +++ google/cloud/bigtable/data/_async/client.py | 43 +++++++++++++ .../bigtable/data/_async/mutations_batcher.py | 22 +++++++ .../cloud/bigtable/data/_sync/cross_sync.py | 11 ++++ tests/system/data/test_system_async.py | 11 ++++ tests/unit/data/_async/test__mutate_rows.py | 4 ++ tests/unit/data/_async/test__read_rows.py | 11 ++++ tests/unit/data/_async/test_client.py | 64 +++++++++++++++++++ .../data/_async/test_mutations_batcher.py | 15 +++++ .../data/_async/test_read_rows_acceptance.py | 11 ++++ 11 files changed, 211 insertions(+) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 7c40b492c..3feb64b68 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -40,6 +40,9 @@ ) +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", +) class _MutateRowsOperationAsync: """ MutateRowsOperation manages the logic of sending a set of row mutations, @@ -59,6 +62,12 @@ class _MutateRowsOperationAsync: If not specified, the request will run until operation_timeout is reached. """ + @CrossSync.convert( + replace_symbols={ + "BigtableAsyncClient": "BigtableClient", + "TableAsync": "Table", + } + ) def __init__( self, gapic_client: "BigtableAsyncClient", @@ -108,6 +117,7 @@ def __init__( self.remaining_indices = list(range(len(self.mutations))) self.errors: dict[int, list[Exception]] = {} + @CrossSync.convert async def start(self): """ Start the operation, and run until completion @@ -144,6 +154,7 @@ async def start(self): all_errors, len(self.mutations) ) + @CrossSync.convert async def _run_attempt(self): """ Run a single attempt of the mutate_rows rpc. 
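As a rough sketch of how the generator resolves the path argument shown above into
an output location (values reproduced here for illustration only):

    sync_path = "google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation"
    out_file = "/".join(sync_path.rsplit(".")[:-1]) + ".py"
    # -> "google/cloud/bigtable/data/_sync/_mutate_rows.py"
    sync_cls_name = sync_path.rsplit(".", 1)[-1]
    # -> "_MutateRowsOperation"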
diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index dfc9c1adb..989430f64 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -45,6 +45,9 @@ from google.cloud.bigtable.data._async.client import TableAsync +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", +) class _ReadRowsOperationAsync: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream @@ -76,6 +79,7 @@ class _ReadRowsOperationAsync: "_remaining_count", ) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def __init__( self, query: ReadRowsQuery, @@ -156,6 +160,7 @@ def _read_rows_attempt(self) -> CrossSync.Iterable[Row]: chunked_stream = self.chunk_stream(gapic_stream) return self.merge_rows(chunked_stream) + @CrossSync.convert async def chunk_stream( self, stream: CrossSync.Awaitable[CrossSync.Iterable[ReadRowsResponsePB]] ) -> CrossSync.Iterable[ReadRowsResponsePB.CellChunk]: @@ -208,6 +213,9 @@ async def chunk_stream( current_key = None @staticmethod + @CrossSync.convert( + replace_symbols={"__aiter__": "__iter__", "__anext__": "__next__"} + ) async def merge_rows( chunks: CrossSync.Iterable[ReadRowsResponsePB.CellChunk] | None, ) -> CrossSync.Iterable[Row]: diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index a2ac7f2cc..64efe25b5 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -97,7 +97,17 @@ from google.cloud.bigtable.data._helpers import ShardedQuery +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync.client.BigtableDataClient", +) class BigtableDataClientAsync(ClientWithProject): + @CrossSync.convert( + replace_symbols={ + "BigtableAsyncClient": "BigtableClient", + "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", + "AsyncPooledChannel": "PooledChannel", + } + ) def __init__( self, *, @@ -238,6 +248,7 @@ def _start_background_channel_refresh(self) -> None: ) self._channel_refresh_tasks.append(refresh_task) + @CrossSync.convert async def close(self, timeout: float | None = 2.0): """ Cancel all background tasks @@ -251,6 +262,7 @@ async def close(self, timeout: float | None = 2.0): await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) self._channel_refresh_tasks = [] + @CrossSync.convert async def _ping_and_warm_instances( self, channel: Channel, instance_key: _WarmedInstanceKey | None = None ) -> list[BaseException | None]: @@ -292,6 +304,7 @@ async def _ping_and_warm_instances( ) return [r or None for r in result_list] + @CrossSync.convert async def _manage_channel( self, channel_idx: int, @@ -351,6 +364,7 @@ async def _manage_channel( next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = next_refresh - (time.monotonic() - start_timestamp) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: """ Registers an instance with the client, and warms the channel pool @@ -381,6 +395,7 @@ async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: # refresh tasks aren't active. 
start them as background tasks self._start_background_channel_refresh() + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def _remove_instance_registration( self, instance_id: str, owner: TableAsync ) -> bool: @@ -411,6 +426,7 @@ async def _remove_instance_registration( except KeyError: return False + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: """ Returns a table instance for making data API requests. All arguments are passed @@ -452,15 +468,18 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs """ return TableAsync(self, instance_id, table_id, *args, **kwargs) + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): self._start_background_channel_refresh() return self + @CrossSync.convert(sync_name="__exit__", replace_symbols={"__aexit__": "__exit__"}) async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) +@CrossSync.export_sync(path="google.cloud.bigtable.data._sync.client.Table") class TableAsync: """ Main Data API surface @@ -469,6 +488,9 @@ class TableAsync: each call """ + @CrossSync.convert( + replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} + ) def __init__( self, client: BigtableDataClientAsync, @@ -589,6 +611,12 @@ def __init__( f"{self.__class__.__name__} must be created within an async event loop context." ) from e + @CrossSync.convert( + replace_symbols={ + "AsyncIterable": "Iterable", + "_ReadRowsOperationAsync": "_ReadRowsOperation", + } + ) async def read_rows_stream( self, query: ReadRowsQuery, @@ -639,6 +667,7 @@ async def read_rows_stream( ) return row_merger.start_operation() + @CrossSync.convert async def read_rows( self, query: ReadRowsQuery, @@ -686,6 +715,7 @@ async def read_rows( ) return [row async for row in row_generator] + @CrossSync.convert async def read_row( self, row_key: str | bytes, @@ -735,6 +765,7 @@ async def read_row( return None return results[0] + @CrossSync.convert async def read_rows_sharded( self, sharded_query: ShardedQuery, @@ -834,6 +865,7 @@ async def read_rows_with_semaphore(query): ) return results_list + @CrossSync.convert async def row_exists( self, row_key: str | bytes, @@ -882,6 +914,7 @@ async def row_exists( ) return len(results) > 0 + @CrossSync.convert async def sample_row_keys( self, *, @@ -954,6 +987,7 @@ async def execute_rpc(): exception_factory=_retry_exception_factory, ) + @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) def mutations_batcher( self, *, @@ -1003,6 +1037,7 @@ def mutations_batcher( batch_retryable_errors=batch_retryable_errors, ) + @CrossSync.convert async def mutate_row( self, row_key: str | bytes, @@ -1081,6 +1116,9 @@ async def mutate_row( exception_factory=_retry_exception_factory, ) + @CrossSync.convert( + replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"} + ) async def bulk_mutate_rows( self, mutation_entries: list[RowMutationEntry], @@ -1136,6 +1174,7 @@ async def bulk_mutate_rows( ) await operation.start() + @CrossSync.convert async def check_and_mutate_row( self, row_key: str | bytes, @@ -1202,6 +1241,7 @@ async def check_and_mutate_row( ) return result.predicate_matched + @CrossSync.convert async def read_modify_write_row( self, row_key: str | bytes, @@ -1252,6 +1292,7 @@ async def read_modify_write_row( # construct Row from result return Row._from_pb(result.row) + @CrossSync.convert async 
def close(self): """ Called to close the Table instance and release any resources held by it. @@ -1260,6 +1301,7 @@ async def close(self): self._register_instance_future.cancel() await self.client._remove_instance_registration(self.instance_id, self) + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """ Implement async context manager protocol @@ -1271,6 +1313,7 @@ async def __aenter__(self): await self._register_instance_future return self + @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc_val, exc_tb): """ Implement async context manager protocol diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index caa35425c..53aea4db0 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -45,6 +45,9 @@ from google.cloud.bigtable.data._async.client import TableAsync +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync.mutations_batcher._FlowControl" +) class _FlowControlAsync: """ Manages flow control for batched mutations. Mutations are registered against @@ -101,6 +104,7 @@ def _has_capacity(self, additional_count: int, additional_size: int) -> bool: new_count = self._in_flight_mutation_count + additional_count return new_size <= acceptable_size and new_count <= acceptable_count + @CrossSync.convert async def remove_from_flow( self, mutations: RowMutationEntry | list[RowMutationEntry] ) -> None: @@ -122,6 +126,7 @@ async def remove_from_flow( async with self._capacity_condition: self._capacity_condition.notify_all() + @CrossSync.convert async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): """ Generator function that registers mutations with flow control. As mutations @@ -171,6 +176,10 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] yield mutations[start_idx:end_idx] +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", + mypy_ignore=["unreachable"], +) class MutationsBatcherAsync: """ Allows users to send batches using context manager API: @@ -202,6 +211,9 @@ class MutationsBatcherAsync: Defaults to the Table's default_mutate_rows_retryable_errors. 
""" + @CrossSync.convert( + replace_symbols={"TableAsync": "Table", "_FlowControlAsync": "_FlowControl"} + ) def __init__( self, table: TableAsync, @@ -257,6 +269,7 @@ def __init__( # clean up on program exit atexit.register(self._on_exit) + @CrossSync.convert async def _timer_routine(self, interval: float | None) -> None: """ Set up a background task to flush the batcher every interval seconds @@ -277,6 +290,7 @@ async def _timer_routine(self, interval: float | None) -> None: if not self._closed.is_set() and self._staged_entries: self._schedule_flush() + @CrossSync.convert async def append(self, mutation_entry: RowMutationEntry): """ Add a new set of mutations to the internal queue @@ -326,6 +340,7 @@ def _schedule_flush(self) -> CrossSync.Future[None] | None: return new_task return None + @CrossSync.convert async def _flush_internal(self, new_entries: list[RowMutationEntry]): """ Flushes a set of mutations to the server, and updates internal state @@ -346,6 +361,9 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): self._entries_processed_since_last_raise += len(new_entries) self._add_exceptions(found_exceptions) + @CrossSync.convert( + replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"} + ) async def _execute_mutate_rows( self, batch: list[RowMutationEntry] ) -> list[FailedMutationEntryError]: @@ -426,10 +444,12 @@ def _raise_exceptions(self): entry_count=entry_count, ) + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """Allow use of context manager API""" return self + @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc, tb): """ Allow use of context manager API. @@ -446,6 +466,7 @@ def closed(self) -> bool: """ return self._closed.is_set() + @CrossSync.convert async def close(self): """ Flush queue and clean up resources @@ -473,6 +494,7 @@ def _on_exit(self): ) @staticmethod + @CrossSync.convert async def _wait_for_batch_results( *tasks: CrossSync.Future[list[FailedMutationEntryError]] | CrossSync.Future[None], diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 82e708b70..3bdc40b3a 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -220,6 +220,17 @@ class CrossSync(metaclass=_DecoratorMeta): Generator: TypeAlias = AsyncGenerator _decorators: list[AstDecorator] = [ + AstDecorator("export_sync", # decorate classes to convert + required_keywords=["path"], # otput path for generated sync class + replace_symbols={}, # replace specific symbols across entire class + mypy_ignore=(), # set of mypy error codes to ignore in output file + include_file_imports=True # when True, import statements from top of file will be included in output file + ), + AstDecorator("convert", # decorate methods to convert from async to sync + sync_name=None, # use a new name for the sync class + replace_symbols={}, # replace specific symbols within the function + ), + AstDecorator("drop_method"), # decorate methods to drop in sync version of class AstDecorator( "pytest", async_impl=pytest_mark_asyncio ), # decorate test methods to run with pytest-asyncio diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index cbd8d7605..5c37d64d2 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -30,6 +30,7 @@ from google.cloud.bigtable.data._async.client import BigtableDataClientAsync 
+@CrossSync.export_sync(path="tests.system.data.test_system.TempRowBuilder") class TempRowBuilderAsync: """ Used to add rows to a table for testing purposes. @@ -39,6 +40,7 @@ def __init__(self, table): self.rows = [] self.table = table + @CrossSync.convert async def add_row( self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" ): @@ -62,6 +64,7 @@ async def add_row( await self.table.client._gapic_client.mutate_row(request) self.rows.append(row_key) + @CrossSync.convert async def delete_rows(self): if self.rows: request = { @@ -74,13 +77,16 @@ async def delete_rows(self): await self.table.client._gapic_client.mutate_rows(request) +@CrossSync.export_sync(path="tests.system.data.test_system.TestSystem") class TestSystemAsync: + @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}) @CrossSync.pytest_fixture(scope="session") async def client(self): project = os.getenv("GOOGLE_CLOUD_PROJECT") or None async with BigtableDataClientAsync(project=project, pool_size=4) as client: yield client + @CrossSync.convert @CrossSync.pytest_fixture(scope="session") async def table(self, client, table_id, instance_id): async with client.get_table( @@ -127,6 +133,7 @@ def cluster_config(self, project_id): } return cluster + @CrossSync.convert @pytest.mark.usefixtures("table") async def _retrieve_cell_value(self, table, row_key): """ @@ -140,6 +147,7 @@ async def _retrieve_cell_value(self, table, row_key): cell = row.cells[0] return cell.value + @CrossSync.convert async def _create_row_and_mutation( self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" ): @@ -160,6 +168,7 @@ async def _create_row_and_mutation( mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) return row_key, mutation + @CrossSync.convert(replace_symbols={"TempRowBuilderAsync": "TempRowBuilder"}) @CrossSync.pytest_fixture(scope="function") async def temp_rows(self, table): builder = TempRowBuilderAsync(table) @@ -650,6 +659,7 @@ async def test_check_and_mutate_empty_request(self, client, table): assert "No mutations provided" in str(e.value) @pytest.mark.usefixtures("table") + @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @@ -838,6 +848,7 @@ async def test_read_rows_with_filter(self, table, temp_rows): assert row[0].labels == [expected_label] @pytest.mark.usefixtures("table") + @CrossSync.convert(replace_symbols={"__anext__": "__next__", "aclose": "close"}) @CrossSync.pytest async def test_read_rows_stream_close(self, table, temp_rows): """ diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index b7016be81..292cbd692 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -28,6 +28,9 @@ import mock # type: ignore +@CrossSync.export_sync( + path="tests.unit.data._sync.test__mutate_rows.TestMutateRowsOperation", +) class TestMutateRowsOperation: def _target_class(self): if CrossSync.is_async: @@ -59,6 +62,7 @@ def _make_mutation(self, count=1, size=1): mutation.mutations = [mock.Mock()] * count return mutation + @CrossSync.convert async def _mock_stream(self, mutation_list, error_dict): for idx, entry in enumerate(mutation_list): code = error_dict.get(idx, 0) diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 5d9957e1f..ff55ffd20 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ 
b/tests/unit/data/_async/test__read_rows.py @@ -34,6 +34,9 @@ TEST_LABELS = ["label1", "label2"] +@CrossSync.export_sync( + path="tests.unit.data._sync.test__read_rows.TestReadRowsOperation", +) class TestReadRowsOperationAsync: """ Tests helper functions in the ReadRowsOperation class @@ -42,6 +45,9 @@ class TestReadRowsOperationAsync: """ @staticmethod + @CrossSync.convert( + replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} + ) def _get_target_class(): return _ReadRowsOperationAsync @@ -327,6 +333,10 @@ async def mock_stream(): assert "emit count exceeds row limit" in str(e.value) @CrossSync.pytest + @CrossSync.convert( + sync_name="test_close", + replace_symbols={"aclose": "close", "__anext__": "__next__"}, + ) async def test_aclose(self): """ should be able to close a stream safely with aclose. @@ -356,6 +366,7 @@ async def mock_stream(): await wrapped_gen.__anext__() @CrossSync.pytest + @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) async def test_retryable_ignore_repeated_rows(self): """ Duplicate rows should cause an invalid chunk error diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 5370f35d3..0f5775fac 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -76,8 +76,21 @@ ) +@CrossSync.export_sync( + path="tests.unit.data._sync.test_client.TestBigtableDataClient", + replace_symbols={ + "TestTableAsync": "TestTable", + "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", + "grpc_helpers_async": "grpc_helpers", + "PooledChannelAsync": "PooledChannel", + "BigtableAsyncClient": "BigtableClient", + }, +) class TestBigtableDataClientAsync: @staticmethod + @CrossSync.convert( + replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} + ) def _get_target_class(): return BigtableDataClientAsync @@ -287,6 +300,7 @@ async def test_channel_pool_replace(self): assert client.transport._grpc_channel._pool[i] != start_pool[i] await client.close() + @CrossSync.drop_method @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test__start_background_channel_refresh_sync(self): # should raise RuntimeError if called in a sync context @@ -330,6 +344,7 @@ async def test__start_background_channel_refresh(self, pool_size): ping_and_warm.assert_any_call(channel) await client.close() + @CrossSync.drop_method @CrossSync.pytest @pytest.mark.skipif( sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" @@ -1110,6 +1125,7 @@ async def test_context_manager(self): # actually close the client await true_close + @CrossSync.drop_method def test_client_ctor_sync(self): # initializing client in a sync context should raise RuntimeError @@ -1125,11 +1141,16 @@ def test_client_ctor_sync(self): assert client._channel_refresh_tasks == [] +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestTable") class TestTableAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @staticmethod + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def _get_target_class(): return TableAsync @@ -1251,6 +1272,7 @@ async def test_table_ctor_invalid_timeout_values(self): assert "operation_timeout must be greater than 0" in str(e.value) await client.close() + @CrossSync.drop_method def test_table_ctor_sync(self): # initializing client in a sync context should raise RuntimeError client = 
mock.Mock() @@ -1400,6 +1422,7 @@ async def test_customizable_retryable_errors( ) @pytest.mark.parametrize("include_app_profile", [True, False]) @CrossSync.pytest + @CrossSync.convert(replace_symbols={"BigtableAsyncClient": "BigtableClient"}) async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): """check that all requests attach proper metadata headers""" profile = "profile" if include_app_profile else None @@ -1432,18 +1455,26 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ assert "app_profile_id=" not in goog_metadata +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRows") class TestReadRowsAsync: """ Tests for table.read_rows and related methods. """ @staticmethod + @CrossSync.convert( + replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} + ) def _get_operation_class(): return _ReadRowsOperationAsync + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @CrossSync.convert(replace_symbols={"TestTableAsync": "TestTable"}) def _make_table(self, *args, **kwargs): client_mock = mock.Mock() client_mock._register_instance.side_effect = ( @@ -1491,6 +1522,7 @@ def _make_chunk(*args, **kwargs): return ReadRowsResponse.CellChunk(*args, **kwargs) @staticmethod + @CrossSync.convert async def _make_gapic_stream( chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0, @@ -1529,6 +1561,7 @@ def cancel(self): return mock_stream(chunk_list, sleep_time) + @CrossSync.convert async def execute_fn(self, table, *args, **kwargs): return await table.read_rows(*args, **kwargs) @@ -1940,7 +1973,11 @@ async def test_row_exists(self, return_value, expected_result): assert query.filter._to_dict() == expected_filter +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRowsSharded") class TestReadRowsShardedAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -1953,6 +1990,7 @@ async def test_read_rows_sharded_empty_query(self): assert "empty sharded_query" in str(exc.value) @CrossSync.pytest + @CrossSync.convert(replace_symbols={"TestReadRowsAsync": "TestReadRows"}) async def test_read_rows_sharded_multiple_queries(self): """ Test with multiple queries. 
Should return results from both @@ -2160,10 +2198,15 @@ async def mock_call(*args, **kwargs): ) +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestSampleRowKeys") class TestSampleRowKeysAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @CrossSync.convert async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): from google.cloud.bigtable_v2.types import SampleRowKeysResponse @@ -2311,7 +2354,13 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio await table.sample_row_keys() +@CrossSync.export_sync( + path="tests.unit.data._sync.test_client.TestMutateRow", +) class TestMutateRowAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2486,10 +2535,17 @@ async def test_mutate_row_no_mutations(self, mutations): assert e.value.args[0] == "No mutations provided" +@CrossSync.export_sync( + path="tests.unit.data._sync.test_client.TestBulkMutateRows", +) class TestBulkMutateRowsAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @CrossSync.convert async def _mock_response(self, response_list): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -2865,7 +2921,11 @@ async def test_bulk_mutate_error_recovery(self): await table.bulk_mutate_rows(entries, operation_timeout=1000) +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestCheckAndMutateRow") class TestCheckAndMutateRowAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -3016,7 +3076,11 @@ async def test_check_and_mutate_mutations_parsing(self): ) +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadModifyWriteRow") class TestReadModifyWriteRowAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index a26dcfd64..cbf84e798 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -43,8 +43,10 @@ import mock # type: ignore +@CrossSync.export_sync(path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl") class Test_FlowControl: @staticmethod + @CrossSync.convert(replace_symbols={"_FlowControlAsync": "_FlowControl"}) def _target_class(): return _FlowControlAsync @@ -311,7 +313,11 @@ async def test_add_to_flow_oversize(self): assert len(count_results) == 1 +@CrossSync.export_sync( + path="tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" +) class TestMutationsBatcherAsync: + @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) def _get_target_class(self): return MutationsBatcherAsync @@ -473,6 +479,7 @@ async def test_ctor_invalid_values(self): 
self._make_one(batch_attempt_timeout=-1) assert "attempt_timeout must be greater than 0" in str(e.value) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def test_default_argument_consistency(self): """ We supply default arguments in MutationsBatcherAsync.__init__, and in @@ -893,6 +900,7 @@ async def gen(x): instance._oldest_exceptions.clear() instance._newest_exceptions.clear() + @CrossSync.convert async def _mock_gapic_return(self, num=5): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -1008,12 +1016,18 @@ async def test__raise_exceptions(self): instance._raise_exceptions() @CrossSync.pytest + @CrossSync.convert( + sync_name="test___enter__", replace_symbols={"__aenter__": "__enter__"} + ) async def test___aenter__(self): """Should return self""" async with self._make_one() as instance: assert await instance.__aenter__() == instance @CrossSync.pytest + @CrossSync.convert( + sync_name="test___exit__", replace_symbols={"__aexit__": "__exit__"} + ) async def test___aexit__(self): """aexit should call close""" async with self._make_one() as instance: @@ -1197,6 +1211,7 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): ([4], [core_exceptions.DeadlineExceeded]), ], ) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def test_customizable_retryable_errors( self, input_retryables, expected_retryables ): diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 15b181637..6dc4ed79c 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -39,12 +39,21 @@ from google.cloud.bigtable.data._sync.client import BigtableDataClient # noqa: F401 +@CrossSync.export_sync( + path="tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", +) class TestReadRowsAcceptanceAsync: @staticmethod + @CrossSync.convert( + replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} + ) def _get_operation_class(): return _ReadRowsOperationAsync @staticmethod + @CrossSync.convert( + replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} + ) def _get_client_class(): return BigtableDataClientAsync @@ -74,9 +83,11 @@ def extract_results_from_row(row: Row): return results @staticmethod + @CrossSync.convert async def _coro_wrapper(stream): return stream + @CrossSync.convert async def _process_chunks(self, *chunks): async def _row_stream(): yield ReadRowsResponse(chunks=chunks) From 5efc8408c16a85f868c6a88fa6e43302fcc5363c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 14:02:27 -0700 Subject: [PATCH 168/360] fixed import --- .cross_sync/transformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index fff626ca3..1e5ba5b62 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -16,6 +16,7 @@ import ast from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from generate import CrossSyncFileArtifact class SymbolReplacer(ast.NodeTransformer): From eb5cd4894aedad2b0b0d1372ba564462fd71d9e6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 14:10:57 -0700 Subject: [PATCH 169/360] refactor outputs --- .cross_sync/generate.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index d3a234ff7..117f7da8d 100644 --- a/.cross_sync/generate.py +++ 
b/.cross_sync/generate.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations +from typing import Sequence import ast from dataclasses import dataclass, field @@ -75,18 +76,31 @@ def render(self, with_black=True, save_to_disk=False) -> str: f.write(full_str) return full_str -if __name__ == "__main__": + +def convert_files_in_dir(directory: str) -> set[CrossSyncFileArtifact]: import glob - import sys from transformers import CrossSyncClassDecoratorHandler - # find all cross_sync decorated classes - search_root = sys.argv[1] - files = glob.glob(search_root + "/**/*.py", recursive=True) + # find all python files in the directory + files = glob.glob(directory + "/**/*.py", recursive=True) + # keep track of the output sync files pointed to by the input files artifacts: set[CrossSyncFileArtifact] = set() + # run each file through ast transformation to find all annotated classes for file in files: converter = CrossSyncClassDecoratorHandler(file) converter.convert_file(artifacts) - print(artifacts) + # return set of output artifacts + return artifacts + +def save_artifacts(artifacts: Sequence[CrossSyncFileArtifact]): for artifact in artifacts: artifact.render(save_to_disk=True) + + +if __name__ == "__main__": + import sys + + search_root = sys.argv[1] + outputs = convert_files_in_dir(search_root) + print(f"Generated {len(outputs)} artifacts: {[a.file_path for a in outputs]}") + save_artifacts(outputs) From e5a8792455b4b98b411f2cdeed4aa23460375058 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 14:34:35 -0700 Subject: [PATCH 170/360] renamed artifact class; added comments --- .cross_sync/generate.py | 40 +++++++++++++++++++++++++++---------- .cross_sync/transformers.py | 34 +++++++++++++++++++------------ 2 files changed, 50 insertions(+), 24 deletions(-) diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index 117f7da8d..c92d700a2 100644 --- a/.cross_sync/generate.py +++ b/.cross_sync/generate.py @@ -18,27 +18,44 @@ @dataclass -class CrossSyncFileArtifact: +class CrossSyncOutputFile: """ - Used to track an output file location. Collects a number of converted classes, and then - writes them to disk + Represents an output file location. + + Multiple decorated async classes may point to the same output location for + their generated sync code. This class holds all the information needed to + write the output file to disk. """ + # The path to the output file file_path: str + # The import headers to write to the top of the output file + # will be populated when CrossSync.export_sync(include_file_imports=True) imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = field( default_factory=list ) + # The set of sync ast.ClassDef nodes to write to the output file converted_classes: list[ast.ClassDef] = field(default_factory=list) + # the set of classes contained in the file. 
Used to prevent duplicates contained_classes: set[str] = field(default_factory=set) + # the set of mypy error codes to ignore at the file level + # configured using CrossSync.export_sync(mypy_ignore=["error_code"]) mypy_ignore: list[str] = field(default_factory=list) def __hash__(self): return hash(self.file_path) def __repr__(self): - return f"CrossSyncFileArtifact({self.file_path}, classes={[c.name for c in self.converted_classes]})" + return f"CrossSyncOutputFile({self.file_path}, classes={[c.name for c in self.converted_classes]})" def render(self, with_black=True, save_to_disk=False) -> str: + """ + Render the output file as a string. + + Args: + with_black: whether to run the output through black before returning + save_to_disk: whether to write the output to the file path + """ full_str = ( "# Copyright 2024 Google LLC\n" "#\n" @@ -77,24 +94,25 @@ def render(self, with_black=True, save_to_disk=False) -> str: return full_str -def convert_files_in_dir(directory: str) -> set[CrossSyncFileArtifact]: +def convert_files_in_dir(directory: str) -> set[CrossSyncOutputFile]: import glob from transformers import CrossSyncClassDecoratorHandler # find all python files in the directory files = glob.glob(directory + "/**/*.py", recursive=True) - # keep track of the output sync files pointed to by the input files - artifacts: set[CrossSyncFileArtifact] = set() + # keep track of the output files pointed to by the annotated classes + artifacts: set[CrossSyncOutputFile] = set() # run each file through ast transformation to find all annotated classes for file in files: converter = CrossSyncClassDecoratorHandler(file) - converter.convert_file(artifacts) + new_outputs = converter.convert_file(artifacts) + artifacts.update(new_outputs) # return set of output artifacts return artifacts -def save_artifacts(artifacts: Sequence[CrossSyncFileArtifact]): - for artifact in artifacts: - artifact.render(save_to_disk=True) +def save_artifacts(artifacts: Sequence[CrossSyncOutputFile]): + for a in artifacts: + a.render(save_to_disk=True) if __name__ == "__main__": diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 1e5ba5b62..f3a11c249 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -16,7 +16,7 @@ import ast from google.cloud.bigtable.data._sync.cross_sync import CrossSync -from generate import CrossSyncFileArtifact +from generate import CrossSyncOutputFile class SymbolReplacer(ast.NodeTransformer): @@ -179,7 +179,7 @@ class CrossSyncClassDecoratorHandler(ast.NodeTransformer): """ def __init__(self, file_path): self.in_path = file_path - self._artifact_dict: dict[str, CrossSyncFileArtifact] = {} + self._artifact_dict: dict[str, CrossSyncOutputFile] = {} self.imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = [] self.cross_sync_symbol_transformer = SymbolReplacer( {"CrossSync": "CrossSync._Sync_Impl"} @@ -187,25 +187,29 @@ def __init__(self, file_path): self.cross_sync_method_handler = CrossSyncMethodDecoratorHandler() def convert_file( - self, artifacts: set[CrossSyncFileArtifact] | None = None - ) -> set[CrossSyncFileArtifact]: + self, artifacts: set[CrossSyncOutputFile] | None = None + ) -> set[CrossSyncOutputFile]: """ - Called to run a file through the transformer. If any classes are marked with a CrossSync decorator, - they will be transformed and added to an artifact for the output file + Called to run a file through the ast transformer. 
+ + If the file contains any classes marked with CrossSync.export_sync, the + classes will be processed according to the decorator arguments, and + a set of CrossSyncOutputFile objects will be returned for each output file. + + If no CrossSync annotations are found, no changes will occur and an + empty set will be returned """ tree = ast.parse(open(self.in_path).read()) self._artifact_dict = {f.file_path: f for f in artifacts or []} self.imports = self._get_imports(tree) self.visit(tree) - found = set(self._artifact_dict.values()) - if artifacts is not None: - artifacts.update(found) - return found + # return set of new artifacts + return set(self._artifact_dict.values()).difference(artifacts or []) def visit_ClassDef(self, node): """ Called for each class in file. If class has a CrossSync decorator, it will be transformed - according to the decorator arguments + according to the decorator arguments. Otherwise, no changes will occur """ try: for decorator in node.decorator_list: @@ -217,7 +221,7 @@ def visit_ClassDef(self, node): sync_cls_name = sync_path.rsplit(".", 1)[-1] # find the artifact file for the save location output_artifact = self._artifact_dict.get( - out_file, CrossSyncFileArtifact(out_file) + out_file, CrossSyncOutputFile(out_file) ) # write converted class details if not already present if sync_cls_name not in output_artifact.contained_classes: @@ -246,7 +250,9 @@ def _transform_class( **kwargs, ) -> ast.ClassDef: """ - Transform async class into sync one, by running through a series of transformers + Transform async class into sync one, by applying the following ast transformations: + - SymbolReplacer: to replace any class-level symbols specified in CrossSync.export_sync(replace_symbols={}) decorator + - CrossSyncMethodDecoratorHandler: to visit each method in the class and apply any CrossSync decorators found """ # update name cls_ast.name = new_name @@ -267,6 +273,8 @@ def _get_imports( ) -> list[ast.Import | ast.ImportFrom | ast.Try | ast.If]: """ Grab the imports from the top of the file + + raw imports, as well as try and if statements at the top level are included """ imports = [] for node in tree.body: From 39ae9078bdfb0aab36ae5274d13500e907394995 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 14:59:06 -0700 Subject: [PATCH 171/360] added back conversion annotations This reverts commit 245bd0894a61b3f309d79cdfa040dd344eabc29f. 
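For reference, a minimal sketch of the annotation pattern these diffs restore. The class and method names below are hypothetical; only the decorator names and keyword arguments (export_sync with its required output path, convert with sync_name/replace_symbols, and drop_method) come from the changes in this series, and they are intended as markers consumed by the ast-based generator rather than as runtime behavior:

    from google.cloud.bigtable.data._sync.cross_sync import CrossSync

    @CrossSync.export_sync(
        path="google.cloud.bigtable.data._sync.example.Example",  # output location of the generated sync class (hypothetical)
        mypy_ignore=["unreachable"],  # optional: mypy error codes to ignore in the generated file
    )
    class ExampleAsync:
        @CrossSync.convert(sync_name="__enter__")
        async def __aenter__(self):
            return self

        @CrossSync.convert(replace_symbols={"TableAsync": "Table"})
        async def shutdown(self, table: "TableAsync"):
            # symbol replacement rewrites the annotation to "Table" in the sync output
            await table.close()

        @CrossSync.drop_method
        async def asyncio_only_helper(self):
            # omitted entirely from the generated sync class
            ...
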
--- .../bigtable/data/_async/_mutate_rows.py | 11 ++++ .../cloud/bigtable/data/_async/_read_rows.py | 8 +++ google/cloud/bigtable/data/_async/client.py | 43 +++++++++++++ .../bigtable/data/_async/mutations_batcher.py | 22 +++++++ .../cloud/bigtable/data/_sync/cross_sync.py | 11 ++++ tests/system/data/test_system_async.py | 11 ++++ tests/unit/data/_async/test__mutate_rows.py | 4 ++ tests/unit/data/_async/test__read_rows.py | 11 ++++ tests/unit/data/_async/test_client.py | 64 +++++++++++++++++++ .../data/_async/test_mutations_batcher.py | 15 +++++ .../data/_async/test_read_rows_acceptance.py | 11 ++++ 11 files changed, 211 insertions(+) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 7c40b492c..3feb64b68 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -40,6 +40,9 @@ ) +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", +) class _MutateRowsOperationAsync: """ MutateRowsOperation manages the logic of sending a set of row mutations, @@ -59,6 +62,12 @@ class _MutateRowsOperationAsync: If not specified, the request will run until operation_timeout is reached. """ + @CrossSync.convert( + replace_symbols={ + "BigtableAsyncClient": "BigtableClient", + "TableAsync": "Table", + } + ) def __init__( self, gapic_client: "BigtableAsyncClient", @@ -108,6 +117,7 @@ def __init__( self.remaining_indices = list(range(len(self.mutations))) self.errors: dict[int, list[Exception]] = {} + @CrossSync.convert async def start(self): """ Start the operation, and run until completion @@ -144,6 +154,7 @@ async def start(self): all_errors, len(self.mutations) ) + @CrossSync.convert async def _run_attempt(self): """ Run a single attempt of the mutate_rows rpc. 
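The annotations being restored in these diffs are consumed by the .cross_sync/generate.py entry point refactored above. A usage sketch mirroring its __main__ block follows; the search directory is illustrative, and the snippet assumes it runs from the .cross_sync directory so that the sibling generate and transformers modules are importable:

    # equivalent to: python .cross_sync/generate.py <search_root>
    from generate import convert_files_in_dir, save_artifacts

    # scan every *.py file under the search root for CrossSync.export_sync classes
    outputs = convert_files_in_dir("../google/cloud/bigtable")
    print(f"Generated {len(outputs)} artifacts: {[a.file_path for a in outputs]}")
    # render each CrossSyncOutputFile (black-formatted by default) to its file_path on disk
    save_artifacts(outputs)
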
diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index dfc9c1adb..989430f64 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -45,6 +45,9 @@ from google.cloud.bigtable.data._async.client import TableAsync +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", +) class _ReadRowsOperationAsync: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream @@ -76,6 +79,7 @@ class _ReadRowsOperationAsync: "_remaining_count", ) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def __init__( self, query: ReadRowsQuery, @@ -156,6 +160,7 @@ def _read_rows_attempt(self) -> CrossSync.Iterable[Row]: chunked_stream = self.chunk_stream(gapic_stream) return self.merge_rows(chunked_stream) + @CrossSync.convert async def chunk_stream( self, stream: CrossSync.Awaitable[CrossSync.Iterable[ReadRowsResponsePB]] ) -> CrossSync.Iterable[ReadRowsResponsePB.CellChunk]: @@ -208,6 +213,9 @@ async def chunk_stream( current_key = None @staticmethod + @CrossSync.convert( + replace_symbols={"__aiter__": "__iter__", "__anext__": "__next__"} + ) async def merge_rows( chunks: CrossSync.Iterable[ReadRowsResponsePB.CellChunk] | None, ) -> CrossSync.Iterable[Row]: diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index a2ac7f2cc..64efe25b5 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -97,7 +97,17 @@ from google.cloud.bigtable.data._helpers import ShardedQuery +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync.client.BigtableDataClient", +) class BigtableDataClientAsync(ClientWithProject): + @CrossSync.convert( + replace_symbols={ + "BigtableAsyncClient": "BigtableClient", + "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", + "AsyncPooledChannel": "PooledChannel", + } + ) def __init__( self, *, @@ -238,6 +248,7 @@ def _start_background_channel_refresh(self) -> None: ) self._channel_refresh_tasks.append(refresh_task) + @CrossSync.convert async def close(self, timeout: float | None = 2.0): """ Cancel all background tasks @@ -251,6 +262,7 @@ async def close(self, timeout: float | None = 2.0): await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) self._channel_refresh_tasks = [] + @CrossSync.convert async def _ping_and_warm_instances( self, channel: Channel, instance_key: _WarmedInstanceKey | None = None ) -> list[BaseException | None]: @@ -292,6 +304,7 @@ async def _ping_and_warm_instances( ) return [r or None for r in result_list] + @CrossSync.convert async def _manage_channel( self, channel_idx: int, @@ -351,6 +364,7 @@ async def _manage_channel( next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = next_refresh - (time.monotonic() - start_timestamp) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: """ Registers an instance with the client, and warms the channel pool @@ -381,6 +395,7 @@ async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: # refresh tasks aren't active. 
start them as background tasks self._start_background_channel_refresh() + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def _remove_instance_registration( self, instance_id: str, owner: TableAsync ) -> bool: @@ -411,6 +426,7 @@ async def _remove_instance_registration( except KeyError: return False + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: """ Returns a table instance for making data API requests. All arguments are passed @@ -452,15 +468,18 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs """ return TableAsync(self, instance_id, table_id, *args, **kwargs) + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): self._start_background_channel_refresh() return self + @CrossSync.convert(sync_name="__exit__", replace_symbols={"__aexit__": "__exit__"}) async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) +@CrossSync.export_sync(path="google.cloud.bigtable.data._sync.client.Table") class TableAsync: """ Main Data API surface @@ -469,6 +488,9 @@ class TableAsync: each call """ + @CrossSync.convert( + replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} + ) def __init__( self, client: BigtableDataClientAsync, @@ -589,6 +611,12 @@ def __init__( f"{self.__class__.__name__} must be created within an async event loop context." ) from e + @CrossSync.convert( + replace_symbols={ + "AsyncIterable": "Iterable", + "_ReadRowsOperationAsync": "_ReadRowsOperation", + } + ) async def read_rows_stream( self, query: ReadRowsQuery, @@ -639,6 +667,7 @@ async def read_rows_stream( ) return row_merger.start_operation() + @CrossSync.convert async def read_rows( self, query: ReadRowsQuery, @@ -686,6 +715,7 @@ async def read_rows( ) return [row async for row in row_generator] + @CrossSync.convert async def read_row( self, row_key: str | bytes, @@ -735,6 +765,7 @@ async def read_row( return None return results[0] + @CrossSync.convert async def read_rows_sharded( self, sharded_query: ShardedQuery, @@ -834,6 +865,7 @@ async def read_rows_with_semaphore(query): ) return results_list + @CrossSync.convert async def row_exists( self, row_key: str | bytes, @@ -882,6 +914,7 @@ async def row_exists( ) return len(results) > 0 + @CrossSync.convert async def sample_row_keys( self, *, @@ -954,6 +987,7 @@ async def execute_rpc(): exception_factory=_retry_exception_factory, ) + @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) def mutations_batcher( self, *, @@ -1003,6 +1037,7 @@ def mutations_batcher( batch_retryable_errors=batch_retryable_errors, ) + @CrossSync.convert async def mutate_row( self, row_key: str | bytes, @@ -1081,6 +1116,9 @@ async def mutate_row( exception_factory=_retry_exception_factory, ) + @CrossSync.convert( + replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"} + ) async def bulk_mutate_rows( self, mutation_entries: list[RowMutationEntry], @@ -1136,6 +1174,7 @@ async def bulk_mutate_rows( ) await operation.start() + @CrossSync.convert async def check_and_mutate_row( self, row_key: str | bytes, @@ -1202,6 +1241,7 @@ async def check_and_mutate_row( ) return result.predicate_matched + @CrossSync.convert async def read_modify_write_row( self, row_key: str | bytes, @@ -1252,6 +1292,7 @@ async def read_modify_write_row( # construct Row from result return Row._from_pb(result.row) + @CrossSync.convert async 
def close(self): """ Called to close the Table instance and release any resources held by it. @@ -1260,6 +1301,7 @@ async def close(self): self._register_instance_future.cancel() await self.client._remove_instance_registration(self.instance_id, self) + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """ Implement async context manager protocol @@ -1271,6 +1313,7 @@ async def __aenter__(self): await self._register_instance_future return self + @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc_val, exc_tb): """ Implement async context manager protocol diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index caa35425c..53aea4db0 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -45,6 +45,9 @@ from google.cloud.bigtable.data._async.client import TableAsync +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync.mutations_batcher._FlowControl" +) class _FlowControlAsync: """ Manages flow control for batched mutations. Mutations are registered against @@ -101,6 +104,7 @@ def _has_capacity(self, additional_count: int, additional_size: int) -> bool: new_count = self._in_flight_mutation_count + additional_count return new_size <= acceptable_size and new_count <= acceptable_count + @CrossSync.convert async def remove_from_flow( self, mutations: RowMutationEntry | list[RowMutationEntry] ) -> None: @@ -122,6 +126,7 @@ async def remove_from_flow( async with self._capacity_condition: self._capacity_condition.notify_all() + @CrossSync.convert async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): """ Generator function that registers mutations with flow control. As mutations @@ -171,6 +176,10 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] yield mutations[start_idx:end_idx] +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", + mypy_ignore=["unreachable"], +) class MutationsBatcherAsync: """ Allows users to send batches using context manager API: @@ -202,6 +211,9 @@ class MutationsBatcherAsync: Defaults to the Table's default_mutate_rows_retryable_errors. 
""" + @CrossSync.convert( + replace_symbols={"TableAsync": "Table", "_FlowControlAsync": "_FlowControl"} + ) def __init__( self, table: TableAsync, @@ -257,6 +269,7 @@ def __init__( # clean up on program exit atexit.register(self._on_exit) + @CrossSync.convert async def _timer_routine(self, interval: float | None) -> None: """ Set up a background task to flush the batcher every interval seconds @@ -277,6 +290,7 @@ async def _timer_routine(self, interval: float | None) -> None: if not self._closed.is_set() and self._staged_entries: self._schedule_flush() + @CrossSync.convert async def append(self, mutation_entry: RowMutationEntry): """ Add a new set of mutations to the internal queue @@ -326,6 +340,7 @@ def _schedule_flush(self) -> CrossSync.Future[None] | None: return new_task return None + @CrossSync.convert async def _flush_internal(self, new_entries: list[RowMutationEntry]): """ Flushes a set of mutations to the server, and updates internal state @@ -346,6 +361,9 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): self._entries_processed_since_last_raise += len(new_entries) self._add_exceptions(found_exceptions) + @CrossSync.convert( + replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"} + ) async def _execute_mutate_rows( self, batch: list[RowMutationEntry] ) -> list[FailedMutationEntryError]: @@ -426,10 +444,12 @@ def _raise_exceptions(self): entry_count=entry_count, ) + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """Allow use of context manager API""" return self + @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc, tb): """ Allow use of context manager API. @@ -446,6 +466,7 @@ def closed(self) -> bool: """ return self._closed.is_set() + @CrossSync.convert async def close(self): """ Flush queue and clean up resources @@ -473,6 +494,7 @@ def _on_exit(self): ) @staticmethod + @CrossSync.convert async def _wait_for_batch_results( *tasks: CrossSync.Future[list[FailedMutationEntryError]] | CrossSync.Future[None], diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index b4fc4929d..e73fa6322 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -220,6 +220,17 @@ class CrossSync(metaclass=_DecoratorMeta): Generator: TypeAlias = AsyncGenerator _decorators: list[AstDecorator] = [ + AstDecorator("export_sync", # decorate classes to convert + required_keywords=["path"], # otput path for generated sync class + replace_symbols={}, # replace specific symbols across entire class + mypy_ignore=(), # set of mypy error codes to ignore in output file + include_file_imports=True # when True, import statements from top of file will be included in output file + ), + AstDecorator("convert", # decorate methods to convert from async to sync + sync_name=None, # use a new name for the sync class + replace_symbols={}, # replace specific symbols within the function + ), + AstDecorator("drop_method"), # decorate methods to drop in sync version of class AstDecorator( "pytest", async_impl=pytest_mark_asyncio ), # decorate test methods to run with pytest-asyncio diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index cbd8d7605..5c37d64d2 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -30,6 +30,7 @@ from google.cloud.bigtable.data._async.client import BigtableDataClientAsync 
+@CrossSync.export_sync(path="tests.system.data.test_system.TempRowBuilder") class TempRowBuilderAsync: """ Used to add rows to a table for testing purposes. @@ -39,6 +40,7 @@ def __init__(self, table): self.rows = [] self.table = table + @CrossSync.convert async def add_row( self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" ): @@ -62,6 +64,7 @@ async def add_row( await self.table.client._gapic_client.mutate_row(request) self.rows.append(row_key) + @CrossSync.convert async def delete_rows(self): if self.rows: request = { @@ -74,13 +77,16 @@ async def delete_rows(self): await self.table.client._gapic_client.mutate_rows(request) +@CrossSync.export_sync(path="tests.system.data.test_system.TestSystem") class TestSystemAsync: + @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}) @CrossSync.pytest_fixture(scope="session") async def client(self): project = os.getenv("GOOGLE_CLOUD_PROJECT") or None async with BigtableDataClientAsync(project=project, pool_size=4) as client: yield client + @CrossSync.convert @CrossSync.pytest_fixture(scope="session") async def table(self, client, table_id, instance_id): async with client.get_table( @@ -127,6 +133,7 @@ def cluster_config(self, project_id): } return cluster + @CrossSync.convert @pytest.mark.usefixtures("table") async def _retrieve_cell_value(self, table, row_key): """ @@ -140,6 +147,7 @@ async def _retrieve_cell_value(self, table, row_key): cell = row.cells[0] return cell.value + @CrossSync.convert async def _create_row_and_mutation( self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" ): @@ -160,6 +168,7 @@ async def _create_row_and_mutation( mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) return row_key, mutation + @CrossSync.convert(replace_symbols={"TempRowBuilderAsync": "TempRowBuilder"}) @CrossSync.pytest_fixture(scope="function") async def temp_rows(self, table): builder = TempRowBuilderAsync(table) @@ -650,6 +659,7 @@ async def test_check_and_mutate_empty_request(self, client, table): assert "No mutations provided" in str(e.value) @pytest.mark.usefixtures("table") + @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) @CrossSync.Retry( predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 ) @@ -838,6 +848,7 @@ async def test_read_rows_with_filter(self, table, temp_rows): assert row[0].labels == [expected_label] @pytest.mark.usefixtures("table") + @CrossSync.convert(replace_symbols={"__anext__": "__next__", "aclose": "close"}) @CrossSync.pytest async def test_read_rows_stream_close(self, table, temp_rows): """ diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index b7016be81..292cbd692 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -28,6 +28,9 @@ import mock # type: ignore +@CrossSync.export_sync( + path="tests.unit.data._sync.test__mutate_rows.TestMutateRowsOperation", +) class TestMutateRowsOperation: def _target_class(self): if CrossSync.is_async: @@ -59,6 +62,7 @@ def _make_mutation(self, count=1, size=1): mutation.mutations = [mock.Mock()] * count return mutation + @CrossSync.convert async def _mock_stream(self, mutation_list, error_dict): for idx, entry in enumerate(mutation_list): code = error_dict.get(idx, 0) diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 5d9957e1f..ff55ffd20 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ 
b/tests/unit/data/_async/test__read_rows.py @@ -34,6 +34,9 @@ TEST_LABELS = ["label1", "label2"] +@CrossSync.export_sync( + path="tests.unit.data._sync.test__read_rows.TestReadRowsOperation", +) class TestReadRowsOperationAsync: """ Tests helper functions in the ReadRowsOperation class @@ -42,6 +45,9 @@ class TestReadRowsOperationAsync: """ @staticmethod + @CrossSync.convert( + replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} + ) def _get_target_class(): return _ReadRowsOperationAsync @@ -327,6 +333,10 @@ async def mock_stream(): assert "emit count exceeds row limit" in str(e.value) @CrossSync.pytest + @CrossSync.convert( + sync_name="test_close", + replace_symbols={"aclose": "close", "__anext__": "__next__"}, + ) async def test_aclose(self): """ should be able to close a stream safely with aclose. @@ -356,6 +366,7 @@ async def mock_stream(): await wrapped_gen.__anext__() @CrossSync.pytest + @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) async def test_retryable_ignore_repeated_rows(self): """ Duplicate rows should cause an invalid chunk error diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 5370f35d3..0f5775fac 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -76,8 +76,21 @@ ) +@CrossSync.export_sync( + path="tests.unit.data._sync.test_client.TestBigtableDataClient", + replace_symbols={ + "TestTableAsync": "TestTable", + "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", + "grpc_helpers_async": "grpc_helpers", + "PooledChannelAsync": "PooledChannel", + "BigtableAsyncClient": "BigtableClient", + }, +) class TestBigtableDataClientAsync: @staticmethod + @CrossSync.convert( + replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} + ) def _get_target_class(): return BigtableDataClientAsync @@ -287,6 +300,7 @@ async def test_channel_pool_replace(self): assert client.transport._grpc_channel._pool[i] != start_pool[i] await client.close() + @CrossSync.drop_method @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test__start_background_channel_refresh_sync(self): # should raise RuntimeError if called in a sync context @@ -330,6 +344,7 @@ async def test__start_background_channel_refresh(self, pool_size): ping_and_warm.assert_any_call(channel) await client.close() + @CrossSync.drop_method @CrossSync.pytest @pytest.mark.skipif( sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" @@ -1110,6 +1125,7 @@ async def test_context_manager(self): # actually close the client await true_close + @CrossSync.drop_method def test_client_ctor_sync(self): # initializing client in a sync context should raise RuntimeError @@ -1125,11 +1141,16 @@ def test_client_ctor_sync(self): assert client._channel_refresh_tasks == [] +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestTable") class TestTableAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @staticmethod + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def _get_target_class(): return TableAsync @@ -1251,6 +1272,7 @@ async def test_table_ctor_invalid_timeout_values(self): assert "operation_timeout must be greater than 0" in str(e.value) await client.close() + @CrossSync.drop_method def test_table_ctor_sync(self): # initializing client in a sync context should raise RuntimeError client = 
mock.Mock() @@ -1400,6 +1422,7 @@ async def test_customizable_retryable_errors( ) @pytest.mark.parametrize("include_app_profile", [True, False]) @CrossSync.pytest + @CrossSync.convert(replace_symbols={"BigtableAsyncClient": "BigtableClient"}) async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): """check that all requests attach proper metadata headers""" profile = "profile" if include_app_profile else None @@ -1432,18 +1455,26 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ assert "app_profile_id=" not in goog_metadata +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRows") class TestReadRowsAsync: """ Tests for table.read_rows and related methods. """ @staticmethod + @CrossSync.convert( + replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} + ) def _get_operation_class(): return _ReadRowsOperationAsync + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @CrossSync.convert(replace_symbols={"TestTableAsync": "TestTable"}) def _make_table(self, *args, **kwargs): client_mock = mock.Mock() client_mock._register_instance.side_effect = ( @@ -1491,6 +1522,7 @@ def _make_chunk(*args, **kwargs): return ReadRowsResponse.CellChunk(*args, **kwargs) @staticmethod + @CrossSync.convert async def _make_gapic_stream( chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0, @@ -1529,6 +1561,7 @@ def cancel(self): return mock_stream(chunk_list, sleep_time) + @CrossSync.convert async def execute_fn(self, table, *args, **kwargs): return await table.read_rows(*args, **kwargs) @@ -1940,7 +1973,11 @@ async def test_row_exists(self, return_value, expected_result): assert query.filter._to_dict() == expected_filter +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRowsSharded") class TestReadRowsShardedAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -1953,6 +1990,7 @@ async def test_read_rows_sharded_empty_query(self): assert "empty sharded_query" in str(exc.value) @CrossSync.pytest + @CrossSync.convert(replace_symbols={"TestReadRowsAsync": "TestReadRows"}) async def test_read_rows_sharded_multiple_queries(self): """ Test with multiple queries. 
Should return results from both @@ -2160,10 +2198,15 @@ async def mock_call(*args, **kwargs): ) +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestSampleRowKeys") class TestSampleRowKeysAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @CrossSync.convert async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): from google.cloud.bigtable_v2.types import SampleRowKeysResponse @@ -2311,7 +2354,13 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio await table.sample_row_keys() +@CrossSync.export_sync( + path="tests.unit.data._sync.test_client.TestMutateRow", +) class TestMutateRowAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -2486,10 +2535,17 @@ async def test_mutate_row_no_mutations(self, mutations): assert e.value.args[0] == "No mutations provided" +@CrossSync.export_sync( + path="tests.unit.data._sync.test_client.TestBulkMutateRows", +) class TestBulkMutateRowsAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) + @CrossSync.convert async def _mock_response(self, response_list): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -2865,7 +2921,11 @@ async def test_bulk_mutate_error_recovery(self): await table.bulk_mutate_rows(entries, operation_timeout=1000) +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestCheckAndMutateRow") class TestCheckAndMutateRowAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) @@ -3016,7 +3076,11 @@ async def test_check_and_mutate_mutations_parsing(self): ) +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadModifyWriteRow") class TestReadModifyWriteRowAsync: + @CrossSync.convert( + replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} + ) def _make_client(self, *args, **kwargs): return TestBigtableDataClientAsync._make_client(*args, **kwargs) diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index a26dcfd64..cbf84e798 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -43,8 +43,10 @@ import mock # type: ignore +@CrossSync.export_sync(path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl") class Test_FlowControl: @staticmethod + @CrossSync.convert(replace_symbols={"_FlowControlAsync": "_FlowControl"}) def _target_class(): return _FlowControlAsync @@ -311,7 +313,11 @@ async def test_add_to_flow_oversize(self): assert len(count_results) == 1 +@CrossSync.export_sync( + path="tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" +) class TestMutationsBatcherAsync: + @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) def _get_target_class(self): return MutationsBatcherAsync @@ -473,6 +479,7 @@ async def test_ctor_invalid_values(self): 
self._make_one(batch_attempt_timeout=-1) assert "attempt_timeout must be greater than 0" in str(e.value) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def test_default_argument_consistency(self): """ We supply default arguments in MutationsBatcherAsync.__init__, and in @@ -893,6 +900,7 @@ async def gen(x): instance._oldest_exceptions.clear() instance._newest_exceptions.clear() + @CrossSync.convert async def _mock_gapic_return(self, num=5): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -1008,12 +1016,18 @@ async def test__raise_exceptions(self): instance._raise_exceptions() @CrossSync.pytest + @CrossSync.convert( + sync_name="test___enter__", replace_symbols={"__aenter__": "__enter__"} + ) async def test___aenter__(self): """Should return self""" async with self._make_one() as instance: assert await instance.__aenter__() == instance @CrossSync.pytest + @CrossSync.convert( + sync_name="test___exit__", replace_symbols={"__aexit__": "__exit__"} + ) async def test___aexit__(self): """aexit should call close""" async with self._make_one() as instance: @@ -1197,6 +1211,7 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): ([4], [core_exceptions.DeadlineExceeded]), ], ) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def test_customizable_retryable_errors( self, input_retryables, expected_retryables ): diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 15b181637..6dc4ed79c 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -39,12 +39,21 @@ from google.cloud.bigtable.data._sync.client import BigtableDataClient # noqa: F401 +@CrossSync.export_sync( + path="tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", +) class TestReadRowsAcceptanceAsync: @staticmethod + @CrossSync.convert( + replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} + ) def _get_operation_class(): return _ReadRowsOperationAsync @staticmethod + @CrossSync.convert( + replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} + ) def _get_client_class(): return BigtableDataClientAsync @@ -74,9 +83,11 @@ def extract_results_from_row(row: Row): return results @staticmethod + @CrossSync.convert async def _coro_wrapper(stream): return stream + @CrossSync.convert async def _process_chunks(self, *chunks): async def _row_stream(): yield ReadRowsResponse(chunks=chunks) From fd1fb711723ebe7c5b9d56352845f3602b2a6d77 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 15:14:37 -0700 Subject: [PATCH 172/360] fixed decorator sync_impl call --- google/cloud/bigtable/data/_sync/cross_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index e73fa6322..f8d7202a8 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -116,7 +116,7 @@ def __call__(self, *args, **kwargs): raise ValueError(f"Invalid keyword argument: {kwarg}") # if async_impl is provided, use the given decorator function if self.async_impl: - return self.async_impl(**{**self.default_kwargs, **kwargs}) + return self.async_impl(*args, **{**self.default_kwargs, **kwargs}) # if no arguments, args[0] will hold the function to be decorated # return the function as is if len(args) == 1 and callable(args[0]): From 
adb092eb019abb4c31d0c404d39ed27246603e5b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 16:27:41 -0700 Subject: [PATCH 173/360] use add_mapping in place of replace_symbols --- .../bigtable/data/_async/_mutate_rows.py | 13 +++---- .../cloud/bigtable/data/_async/_read_rows.py | 4 +- google/cloud/bigtable/data/_async/client.py | 38 ++++++++----------- .../bigtable/data/_async/mutations_batcher.py | 8 ++-- .../cloud/bigtable/data/_sync/cross_sync.py | 17 +++++++++ .../data/_async/test_mutations_batcher.py | 14 +------ 6 files changed, 45 insertions(+), 49 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 3feb64b68..f3ced4849 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -38,6 +38,8 @@ from google.cloud.bigtable_v2.services.bigtable.async_client import ( BigtableAsyncClient, ) + CrossSync.add_mapping("Table", TableAsync) + CrossSync.add_mapping("GapicClient", BigtableAsyncClient) @CrossSync.export_sync( @@ -62,16 +64,11 @@ class _MutateRowsOperationAsync: If not specified, the request will run until operation_timeout is reached. """ - @CrossSync.convert( - replace_symbols={ - "BigtableAsyncClient": "BigtableClient", - "TableAsync": "Table", - } - ) + @CrossSync.convert def __init__( self, - gapic_client: "BigtableAsyncClient", - table: "TableAsync", + gapic_client: "CrossSync.GapicClient", + table: "CrossSync.Table", mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 989430f64..d6b8f213c 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -43,6 +43,7 @@ if TYPE_CHECKING: if CrossSync.is_async: from google.cloud.bigtable.data._async.client import TableAsync + CrossSync.add_mapping("Table", TableAsync) @CrossSync.export_sync( @@ -79,11 +80,10 @@ class _ReadRowsOperationAsync: "_remaining_count", ) - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def __init__( self, query: ReadRowsQuery, - table: "TableAsync", + table: "CrossSync.Table", operation_timeout: float, attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 64efe25b5..dd5644fe0 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -91,6 +91,13 @@ from google.cloud.bigtable_v2.services.bigtable.async_client import ( BigtableAsyncClient, ) + # define file-specific cross-sync replacements + CrossSync.add_mapping("GapicClient", BigtableAsyncClient) + CrossSync.add_mapping("PooledTransport", PooledBigtableGrpcAsyncIOTransport) + CrossSync.add_mapping("PooledChannel", AsyncPooledChannel) + CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) + CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) + if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples @@ -101,13 +108,7 @@ path="google.cloud.bigtable.data._sync.client.BigtableDataClient", ) class BigtableDataClientAsync(ClientWithProject): - @CrossSync.convert( - replace_symbols={ - "BigtableAsyncClient": "BigtableClient", - "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", - "AsyncPooledChannel": 
"PooledChannel", - } - ) + @CrossSync.convert def __init__( self, *, @@ -143,7 +144,7 @@ def __init__( """ # set up transport in registry transport_str = f"bt-{self._client_version()}-{pool_size}" - transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) + transport = CrossSync.PooledTransport.with_fixed_size(pool_size) BigtableClientMeta._transport_registry[transport_str] = transport # set up client info headers for veneer library client_info = DEFAULT_CLIENT_INFO @@ -168,7 +169,7 @@ def __init__( project=project, client_options=client_options, ) - self._gapic_client = BigtableAsyncClient( + self._gapic_client = CrossSync.GapicClient( transport=transport_str, credentials=credentials, client_options=client_options, @@ -176,7 +177,7 @@ def __init__( ) self._is_closed = CrossSync.Event() self.transport = cast( - PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport + CrossSync.PooledTransport, self._gapic_client.transport ) # keep track of active instances to for warmup on channel refresh self._active_instances: Set[_WarmedInstanceKey] = set() @@ -195,7 +196,7 @@ def __init__( RuntimeWarning, stacklevel=2, ) - self.transport._grpc_channel = AsyncPooledChannel( + self.transport._grpc_channel = CrossSync.PooledChannel( pool_size=pool_size, host=self._emulator_host, insecure=True, @@ -611,12 +612,7 @@ def __init__( f"{self.__class__.__name__} must be created within an async event loop context." ) from e - @CrossSync.convert( - replace_symbols={ - "AsyncIterable": "Iterable", - "_ReadRowsOperationAsync": "_ReadRowsOperation", - } - ) + @CrossSync.convert(replace_symbols={"AsyncIterable": "Iterable"}) async def read_rows_stream( self, query: ReadRowsQuery, @@ -658,7 +654,7 @@ async def read_rows_stream( ) retryable_excs = _get_retryable_errors(retryable_errors, self) - row_merger = _ReadRowsOperationAsync( + row_merger = CrossSync._ReadRowsOperation( query, self, operation_timeout=operation_timeout, @@ -1116,9 +1112,7 @@ async def mutate_row( exception_factory=_retry_exception_factory, ) - @CrossSync.convert( - replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"} - ) + @CrossSync.convert async def bulk_mutate_rows( self, mutation_entries: list[RowMutationEntry], @@ -1164,7 +1158,7 @@ async def bulk_mutate_rows( ) retryable_excs = _get_retryable_errors(retryable_errors, self) - operation = _MutateRowsOperationAsync( + operation = CrossSync._MutateRowsOperation( self.client._gapic_client, self, mutation_entries, diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 53aea4db0..b6415c502 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -36,7 +36,7 @@ if CrossSync.is_async: from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync - + CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -361,9 +361,7 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): self._entries_processed_since_last_raise += len(new_entries) self._add_exceptions(found_exceptions) - @CrossSync.convert( - replace_symbols={"_MutateRowsOperationAsync": "_MutateRowsOperation"} - ) + @CrossSync.convert async def _execute_mutate_rows( self, batch: list[RowMutationEntry] ) -> list[FailedMutationEntryError]: @@ -380,7 +378,7 @@ async def _execute_mutate_rows( FailedMutationEntryError 
objects will not contain index information """ try: - operation = _MutateRowsOperationAsync( + operation = CrossSync._MutateRowsOperation( self._table.client._gapic_client, self._table, batch, diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index f8d7202a8..b3ededf81 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -219,6 +219,21 @@ class CrossSync(metaclass=_DecoratorMeta): Iterator: TypeAlias = AsyncIterator Generator: TypeAlias = AsyncGenerator + @classmethod + def add_mapping(cls, name, value): + """ + Add a new attribute to the CrossSync class, for replacing library-level symbols + + Raises: + - AttributeError if the attribute already exists with a different value + """ + if not hasattr(cls, name): + cls._runtime_replacements.add(name) + elif value != getattr(cls, name): + raise AttributeError(f"Conflicting assignments for CrossSync.{name}") + setattr(cls, name, value) + + # list of decorators that can be applied to classes and methods to guide code generation _decorators: list[AstDecorator] = [ AstDecorator("export_sync", # decorate classes to convert required_keywords=["path"], # otput path for generated sync class @@ -244,6 +259,8 @@ class CrossSync(metaclass=_DecoratorMeta): name=None, ), ] + # list of attributes that can be added to the CrossSync class at runtime + _runtime_replacements: set[Any] = set() @classmethod def Mock(cls, *args, **kwargs): diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index cbf84e798..df5bb0a4c 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -938,11 +938,7 @@ async def test_timer_flush_end_to_end(self): @CrossSync.pytest async def test__execute_mutate_rows(self): - if CrossSync.is_async: - mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" - else: - mutate_path = "_sync.mutations_batcher._MutateRowsOperation" - with mock.patch(f"google.cloud.bigtable.data.{mutate_path}") as mutate_rows: + with mock.patch.object(CrossSync, "_MutateRowsOperation") as mutate_rows: mutate_rows.return_value = CrossSync.Mock() start_operation = mutate_rows().start table = mock.Mock() @@ -1105,13 +1101,7 @@ async def test_timeout_args_passed(self): batch_operation_timeout and batch_attempt_timeout should be used in api calls """ - if CrossSync.is_async: - mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" - else: - mutate_path = "_sync.mutations_batcher._MutateRowsOperation" - with mock.patch( - f"google.cloud.bigtable.data.{mutate_path}", return_value=CrossSync.Mock() - ) as mutate_rows: + with mock.patch.object(CrossSync, "_MutateRowsOperation", return_value=CrossSync.Mock()) as mutate_rows: expected_operation_timeout = 17 expected_attempt_timeout = 13 async with self._make_one( From 45efa16a0b221cf014e43458972edc5aa85c0aaf Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 16:46:33 -0700 Subject: [PATCH 174/360] support automatic attribute registration --- .../bigtable/data/_async/_mutate_rows.py | 2 +- .../cloud/bigtable/data/_async/_read_rows.py | 2 +- google/cloud/bigtable/data/_async/client.py | 5 ++--- .../bigtable/data/_async/mutations_batcher.py | 11 +++++----- .../cloud/bigtable/data/_sync/cross_sync.py | 20 ++++++++++++++++++- 5 files changed, 28 insertions(+), 12 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py 
b/google/cloud/bigtable/data/_async/_mutate_rows.py index f3ced4849..e320a44ea 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -38,12 +38,12 @@ from google.cloud.bigtable_v2.services.bigtable.async_client import ( BigtableAsyncClient, ) - CrossSync.add_mapping("Table", TableAsync) CrossSync.add_mapping("GapicClient", BigtableAsyncClient) @CrossSync.export_sync( path="google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", + add_mapping_for_name="_MutateRowsOperation", ) class _MutateRowsOperationAsync: """ diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index d6b8f213c..b78b3160e 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -43,11 +43,11 @@ if TYPE_CHECKING: if CrossSync.is_async: from google.cloud.bigtable.data._async.client import TableAsync - CrossSync.add_mapping("Table", TableAsync) @CrossSync.export_sync( path="google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", + add_mapping_for_name="_ReadRowsOperation", ) class _ReadRowsOperationAsync: """ diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index dd5644fe0..2d1e3b629 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -95,8 +95,6 @@ CrossSync.add_mapping("GapicClient", BigtableAsyncClient) CrossSync.add_mapping("PooledTransport", PooledBigtableGrpcAsyncIOTransport) CrossSync.add_mapping("PooledChannel", AsyncPooledChannel) - CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) - CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) if TYPE_CHECKING: @@ -106,6 +104,7 @@ @CrossSync.export_sync( path="google.cloud.bigtable.data._sync.client.BigtableDataClient", + add_mapping_for_name="DataClient", ) class BigtableDataClientAsync(ClientWithProject): @CrossSync.convert @@ -480,7 +479,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) -@CrossSync.export_sync(path="google.cloud.bigtable.data._sync.client.Table") +@CrossSync.export_sync(path="google.cloud.bigtable.data._sync.client.Table", add_mapping_for_name="Table") class TableAsync: """ Main Data API surface diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index b6415c502..3a49ef051 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -36,7 +36,6 @@ if CrossSync.is_async: from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync - CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -46,7 +45,8 @@ @CrossSync.export_sync( - path="google.cloud.bigtable.data._sync.mutations_batcher._FlowControl" + path="google.cloud.bigtable.data._sync.mutations_batcher._FlowControl", + add_mapping_for_name="_FlowControl" ) class _FlowControlAsync: """ @@ -179,6 +179,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] @CrossSync.export_sync( path="google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", mypy_ignore=["unreachable"], + add_mapping_for_name="MutationsBatcher", ) class MutationsBatcherAsync: """ @@ -211,9 +212,7 @@ class 
MutationsBatcherAsync: Defaults to the Table's default_mutate_rows_retryable_errors. """ - @CrossSync.convert( - replace_symbols={"TableAsync": "Table", "_FlowControlAsync": "_FlowControl"} - ) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def __init__( self, table: TableAsync, @@ -239,7 +238,7 @@ def __init__( self._table = table self._staged_entries: list[RowMutationEntry] = [] self._staged_count, self._staged_bytes = 0, 0 - self._flow_control = _FlowControlAsync( + self._flow_control = CrossSync._FlowControl( flow_control_max_mutation_count, flow_control_max_bytes ) self._flush_limit_bytes = flush_limit_bytes diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index b3ededf81..194eb7524 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -72,6 +72,22 @@ def decorator(func): return decorator +def export_sync_impl(*args, **kwargs): + """ + Decorator implementation for CrossSync.export_sync + + When a called with add_mapping_for_name, CrossSync.add_mapping is called to + register the name as a CrossSync attribute + """ + new_mapping = kwargs.pop("add_mapping_for_name", None) + def decorator(cls): + if new_mapping: + # add class to mappings if requested + CrossSync.add_mapping(new_mapping, cls) + return cls + return decorator + + class AstDecorator: """ Helper class for CrossSync decorators used for guiding ast transformations. @@ -237,9 +253,11 @@ def add_mapping(cls, name, value): _decorators: list[AstDecorator] = [ AstDecorator("export_sync", # decorate classes to convert required_keywords=["path"], # otput path for generated sync class + async_impl=export_sync_impl, # apply this decorator to the function at runtime replace_symbols={}, # replace specific symbols across entire class mypy_ignore=(), # set of mypy error codes to ignore in output file - include_file_imports=True # when True, import statements from top of file will be included in output file + include_file_imports=True, # when True, import statements from top of file will be included in output file + add_mapping_for_name=None, # add a new attribute to CrossSync class with the given name ), AstDecorator("convert", # decorate methods to convert from async to sync sync_name=None, # use a new name for the sync class From e1ec9741de327acd0adeeb81835ad38fdee2b445 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 17:07:12 -0700 Subject: [PATCH 175/360] reduced replace_symbols usage in tests --- tests/system/data/test_system_async.py | 10 +- tests/unit/data/_async/test__mutate_rows.py | 13 +- tests/unit/data/_async/test__read_rows.py | 6 +- tests/unit/data/_async/test_client.py | 113 +++++++----------- .../data/_async/test_mutations_batcher.py | 14 +-- .../data/_async/test_read_rows_acceptance.py | 12 +- 6 files changed, 63 insertions(+), 105 deletions(-) diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 5c37d64d2..2822c1b6c 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -30,7 +30,7 @@ from google.cloud.bigtable.data._async.client import BigtableDataClientAsync -@CrossSync.export_sync(path="tests.system.data.test_system.TempRowBuilder") +@CrossSync.export_sync(path="tests.system.data.test_system.TempRowBuilder", add_mapping_for_name="TempRowBuilder") class TempRowBuilderAsync: """ Used to add rows to a table for testing purposes. 
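A minimal, self-contained sketch of the registration behavior that add_mapping_for_name is relied on for in the hunks nearby; the stub below stands in for the real CrossSync class and is illustrative only, not part of the patch:

class _CrossSyncStub:
    @classmethod
    def add_mapping(cls, name, value):
        # re-registering the same value is a no-op; a conflicting value is rejected
        if hasattr(cls, name) and getattr(cls, name) is not value:
            raise AttributeError(f"Conflicting assignments for CrossSync.{name}")
        setattr(cls, name, value)


class TempRowBuilderAsync:
    def __init__(self, table):
        self.table = table


# roughly what @CrossSync.export_sync(..., add_mapping_for_name="TempRowBuilder")
# is expected to do at class-definition time:
_CrossSyncStub.add_mapping("TempRowBuilder", TempRowBuilderAsync)

# call sites can then resolve the class through the shared alias instead of
# depending on replace_symbols rewrites at generation time:
builder = _CrossSyncStub.TempRowBuilder(table=None)
assert isinstance(builder, TempRowBuilderAsync)
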
@@ -79,11 +79,11 @@ async def delete_rows(self): @CrossSync.export_sync(path="tests.system.data.test_system.TestSystem") class TestSystemAsync: - @CrossSync.convert(replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}) + @CrossSync.convert @CrossSync.pytest_fixture(scope="session") async def client(self): project = os.getenv("GOOGLE_CLOUD_PROJECT") or None - async with BigtableDataClientAsync(project=project, pool_size=4) as client: + async with CrossSync.DataClient(project=project, pool_size=4) as client: yield client @CrossSync.convert @@ -168,10 +168,10 @@ async def _create_row_and_mutation( mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) return row_key, mutation - @CrossSync.convert(replace_symbols={"TempRowBuilderAsync": "TempRowBuilder"}) + @CrossSync.convert @CrossSync.pytest_fixture(scope="function") async def temp_rows(self, table): - builder = TempRowBuilderAsync(table) + builder = CrossSync.TempRowBuilder(table) yield builder await builder.delete_rows() diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 292cbd692..a307a7008 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -33,18 +33,7 @@ ) class TestMutateRowsOperation: def _target_class(self): - if CrossSync.is_async: - from google.cloud.bigtable.data._async._mutate_rows import ( - _MutateRowsOperationAsync, - ) - - return _MutateRowsOperationAsync - else: - from google.cloud.bigtable.data._sync._mutate_rows import ( - _MutateRowsOperation, - ) - - return _MutateRowsOperation + return CrossSync._MutateRowsOperation def _make_one(self, *args, **kwargs): if not args: diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index ff55ffd20..0ec108df9 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -45,11 +45,9 @@ class TestReadRowsOperationAsync: """ @staticmethod - @CrossSync.convert( - replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} - ) + @CrossSync.convert def _get_target_class(): - return _ReadRowsOperationAsync + return CrossSync._ReadRowsOperation def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 0f5775fac..2aede657e 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -56,6 +56,7 @@ TableAsync, BigtableDataClientAsync, ) + CrossSync.add_mapping("grpc_helpers", grpc_helpers_async) else: from google.api_core import grpc_helpers # noqa: F401 from google.cloud.bigtable_v2.services.bigtable.client import ( # noqa: F401 @@ -74,25 +75,18 @@ Table, BigtableDataClient, ) + CrossSync.add_mapping("grpc_helpers", grpc_helpers) @CrossSync.export_sync( path="tests.unit.data._sync.test_client.TestBigtableDataClient", - replace_symbols={ - "TestTableAsync": "TestTable", - "PooledBigtableGrpcAsyncIOTransport": "PooledBigtableGrpcTransport", - "grpc_helpers_async": "grpc_helpers", - "PooledChannelAsync": "PooledChannel", - "BigtableAsyncClient": "BigtableClient", - }, + add_mapping_for_name="TestBigtableDataClient", ) class TestBigtableDataClientAsync: @staticmethod - @CrossSync.convert( - replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} - ) + @CrossSync.convert def _get_target_class(): - return BigtableDataClientAsync + return CrossSync.DataClient @classmethod def 
_make_client(cls, *args, use_emulator=True, **kwargs): @@ -145,7 +139,7 @@ async def test_ctor_super_inits(self): options_parsed = client_options_lib.from_dict(client_options) asyncio_portion = "-async" if CrossSync.is_async else "" transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" - with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: + with mock.patch.object(CrossSync.GapicClient, "__init__") as bigtable_client_init: bigtable_client_init.return_value = None with mock.patch.object( ClientWithProject, "__init__" @@ -179,7 +173,7 @@ async def test_ctor_dict_options(self): from google.api_core.client_options import ClientOptions client_options = {"api_endpoint": "foo.bar:1234"} - with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: + with mock.patch.object(CrossSync.GapicClient, "__init__") as bigtable_client_init: try: self._make_client(client_options=client_options) except TypeError: @@ -233,7 +227,7 @@ async def test_veneer_grpc_headers(self): async def test_channel_pool_creation(self): pool_size = 14 with mock.patch.object( - grpc_helpers_async, "create_channel", CrossSync.Mock() + CrossSync.grpc_helpers, "create_channel", CrossSync.Mock() ) as create_channel: client = self._make_client(project="project-id", pool_size=pool_size) assert create_channel.call_count == pool_size @@ -248,7 +242,7 @@ async def test_channel_pool_creation(self): @CrossSync.pytest async def test_channel_pool_rotation(self): pool_size = 7 - with mock.patch.object(PooledChannelAsync, "next_channel") as next_channel: + with mock.patch.object(CrossSync.PooledChannel, "next_channel") as next_channel: client = self._make_client(project="project-id", pool_size=pool_size) assert len(client.transport._grpc_channel._pool) == pool_size next_channel.reset_mock() @@ -643,7 +637,7 @@ async def test__manage_channel_refresh(self, num_cycles): new_channel = grpc_lib.insecure_channel("localhost:8080") with mock.patch.object( - PooledBigtableGrpcAsyncIOTransport, "replace_channel" + CrossSync.PooledTransport, "replace_channel" ) as replace_channel: sleep_tuple = ( (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") @@ -653,7 +647,7 @@ async def test__manage_channel_refresh(self, num_cycles): asyncio.CancelledError ] with mock.patch.object( - grpc_helpers_async, "create_channel" + CrossSync.grpc_helpers, "create_channel" ) as create_channel: create_channel.return_value = new_channel with mock.patch.object( @@ -962,7 +956,7 @@ async def test_get_table(self): expected_app_profile_id, ) await CrossSync.yield_to_event_loop() - assert isinstance(table, TestTableAsync._get_target_class()) + assert isinstance(table, CrossSync.TestTable._get_target_class()) assert table.table_id == expected_table_id assert ( table.table_name @@ -989,7 +983,7 @@ async def test_get_table_arg_passthrough(self): """ async with self._make_client(project="project-id") as client: with mock.patch.object( - TestTableAsync._get_target_class(), "__init__" + CrossSync.TestTable._get_target_class(), "__init__" ) as mock_constructor: mock_constructor.return_value = None assert not client._active_instances @@ -1025,7 +1019,7 @@ async def test_get_table_context_manager(self): expected_project_id = "project-id" with mock.patch.object( - TestTableAsync._get_target_class(), "close" + CrossSync.TestTable._get_target_class(), "close" ) as close_mock: async with self._make_client(project=expected_project_id) as client: async with client.get_table( @@ -1034,7 +1028,7 @@ async def 
test_get_table_context_manager(self): expected_app_profile_id, ) as table: await CrossSync.yield_to_event_loop() - assert isinstance(table, TestTableAsync._get_target_class()) + assert isinstance(table, CrossSync.TestTable._get_target_class()) assert table.table_id == expected_table_id assert ( table.table_name @@ -1082,7 +1076,7 @@ async def test_close(self): for task in client._channel_refresh_tasks: assert not task.done() with mock.patch.object( - PooledBigtableGrpcAsyncIOTransport, "close", CrossSync.Mock() + CrossSync.PooledTransport, "close", CrossSync.Mock() ) as close_mock: await client.close() close_mock.assert_called_once() @@ -1141,18 +1135,16 @@ def test_client_ctor_sync(self): assert client._channel_refresh_tasks == [] -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestTable") +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestTable", add_mapping_for_name="TestTable") class TestTableAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) + @CrossSync.convert def _make_client(self, *args, **kwargs): - return TestBigtableDataClientAsync._make_client(*args, **kwargs) + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) @staticmethod - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) + @CrossSync.convert def _get_target_class(): - return TableAsync + return CrossSync.Table @CrossSync.pytest async def test_table_ctor(self): @@ -1422,12 +1414,12 @@ async def test_customizable_retryable_errors( ) @pytest.mark.parametrize("include_app_profile", [True, False]) @CrossSync.pytest - @CrossSync.convert(replace_symbols={"BigtableAsyncClient": "BigtableClient"}) + @CrossSync.convert async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): """check that all requests attach proper metadata headers""" profile = "profile" if include_app_profile else None with mock.patch.object( - BigtableAsyncClient, gapic_fn, CrossSync.Mock() + CrossSync.GapicClient, gapic_fn, CrossSync.Mock() ) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") async with self._make_client() as client: @@ -1455,26 +1447,22 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ assert "app_profile_id=" not in goog_metadata -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRows") +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRows", add_mapping_for_name="TestReadRows") class TestReadRowsAsync: """ Tests for table.read_rows and related methods. 
""" @staticmethod - @CrossSync.convert( - replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} - ) + @CrossSync.convert def _get_operation_class(): - return _ReadRowsOperationAsync + return CrossSync._ReadRowsOperation - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) + @CrossSync.convert def _make_client(self, *args, **kwargs): - return TestBigtableDataClientAsync._make_client(*args, **kwargs) + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) - @CrossSync.convert(replace_symbols={"TestTableAsync": "TestTable"}) + @CrossSync.convert def _make_table(self, *args, **kwargs): client_mock = mock.Mock() client_mock._register_instance.side_effect = ( @@ -1491,7 +1479,7 @@ def _make_table(self, *args, **kwargs): ) client_mock._gapic_client.table_path.return_value = kwargs["table_id"] client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] - return TestTableAsync._get_target_class()(client_mock, *args, **kwargs) + return CrossSync.TestTable._get_target_class()(client_mock, *args, **kwargs) def _make_stats(self): from google.cloud.bigtable_v2.types import RequestStats @@ -1975,11 +1963,9 @@ async def test_row_exists(self, return_value, expected_result): @CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRowsSharded") class TestReadRowsShardedAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) + @CrossSync.convert def _make_client(self, *args, **kwargs): - return TestBigtableDataClientAsync._make_client(*args, **kwargs) + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) @CrossSync.pytest async def test_read_rows_sharded_empty_query(self): @@ -1990,7 +1976,6 @@ async def test_read_rows_sharded_empty_query(self): assert "empty sharded_query" in str(exc.value) @CrossSync.pytest - @CrossSync.convert(replace_symbols={"TestReadRowsAsync": "TestReadRows"}) async def test_read_rows_sharded_multiple_queries(self): """ Test with multiple queries. 
Should return results from both @@ -2001,9 +1986,9 @@ async def test_read_rows_sharded_multiple_queries(self): table.client._gapic_client, "read_rows" ) as read_rows: read_rows.side_effect = ( - lambda *args, **kwargs: TestReadRowsAsync._make_gapic_stream( + lambda *args, **kwargs: CrossSync.TestReadRows._make_gapic_stream( [ - TestReadRowsAsync._make_chunk(row_key=k) + CrossSync.TestReadRows._make_chunk(row_key=k) for k in args[0].rows.row_keys ] ) @@ -2200,11 +2185,9 @@ async def mock_call(*args, **kwargs): @CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestSampleRowKeys") class TestSampleRowKeysAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) + @CrossSync.convert def _make_client(self, *args, **kwargs): - return TestBigtableDataClientAsync._make_client(*args, **kwargs) + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) @CrossSync.convert async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): @@ -2358,11 +2341,9 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio path="tests.unit.data._sync.test_client.TestMutateRow", ) class TestMutateRowAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) + @CrossSync.convert def _make_client(self, *args, **kwargs): - return TestBigtableDataClientAsync._make_client(*args, **kwargs) + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) @CrossSync.pytest @pytest.mark.parametrize( @@ -2539,11 +2520,9 @@ async def test_mutate_row_no_mutations(self, mutations): path="tests.unit.data._sync.test_client.TestBulkMutateRows", ) class TestBulkMutateRowsAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) + @CrossSync.convert def _make_client(self, *args, **kwargs): - return TestBigtableDataClientAsync._make_client(*args, **kwargs) + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) @CrossSync.convert async def _mock_response(self, response_list): @@ -2923,11 +2902,9 @@ async def test_bulk_mutate_error_recovery(self): @CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestCheckAndMutateRow") class TestCheckAndMutateRowAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) + @CrossSync.convert def _make_client(self, *args, **kwargs): - return TestBigtableDataClientAsync._make_client(*args, **kwargs) + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) @pytest.mark.parametrize("gapic_result", [True, False]) @CrossSync.pytest @@ -3078,11 +3055,9 @@ async def test_check_and_mutate_mutations_parsing(self): @CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadModifyWriteRow") class TestReadModifyWriteRowAsync: - @CrossSync.convert( - replace_symbols={"TestBigtableDataClientAsync": "TestBigtableDataClient"} - ) + @CrossSync.convert def _make_client(self, *args, **kwargs): - return TestBigtableDataClientAsync._make_client(*args, **kwargs) + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) @pytest.mark.parametrize( "call_rules,expected_rules", diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index df5bb0a4c..8456f2a38 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -46,7 +46,7 @@ 
@CrossSync.export_sync(path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl") class Test_FlowControl: @staticmethod - @CrossSync.convert(replace_symbols={"_FlowControlAsync": "_FlowControl"}) + @CrossSync.convert def _target_class(): return _FlowControlAsync @@ -317,9 +317,9 @@ async def test_add_to_flow_oversize(self): path="tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" ) class TestMutationsBatcherAsync: - @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) + @CrossSync.convert def _get_target_class(self): - return MutationsBatcherAsync + return CrossSync.MutationsBatcher def _make_one(self, table=None, **kwargs): from google.api_core.exceptions import DeadlineExceeded @@ -479,7 +479,7 @@ async def test_ctor_invalid_values(self): self._make_one(batch_attempt_timeout=-1) assert "attempt_timeout must be greater than 0" in str(e.value) - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) + @CrossSync.convert def test_default_argument_consistency(self): """ We supply default arguments in MutationsBatcherAsync.__init__, and in @@ -489,7 +489,7 @@ def test_default_argument_consistency(self): import inspect get_batcher_signature = dict( - inspect.signature(TableAsync.mutations_batcher).parameters + inspect.signature(CrossSync.Table.mutations_batcher).parameters ) get_batcher_signature.pop("self") batcher_init_signature = dict( @@ -1201,7 +1201,7 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): ([4], [core_exceptions.DeadlineExceeded]), ], ) - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) + @CrossSync.convert async def test_customizable_retryable_errors( self, input_retryables, expected_retryables ): @@ -1215,7 +1215,7 @@ async def test_customizable_retryable_errors( with mock.patch.object(CrossSync, "retry_target") as retry_fn_mock: table = None with mock.patch("asyncio.create_task"): - table = TableAsync(mock.Mock(), "instance", "table") + table = CrossSync.Table(mock.Mock(), "instance", "table") async with self._make_one( table, batch_retryable_errors=input_retryables ) as instance: diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 6dc4ed79c..95bc95e6d 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -44,18 +44,14 @@ ) class TestReadRowsAcceptanceAsync: @staticmethod - @CrossSync.convert( - replace_symbols={"_ReadRowsOperationAsync": "_ReadRowsOperation"} - ) + @CrossSync.convert def _get_operation_class(): - return _ReadRowsOperationAsync + return CrossSync._ReadRowsOperation @staticmethod - @CrossSync.convert( - replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} - ) + @CrossSync.convert def _get_client_class(): - return BigtableDataClientAsync + return CrossSync.DataClient def parse_readrows_acceptance_tests(): dirname = os.path.dirname(__file__) From 021fde2f75ec2cf9e3e82744a237708e666f8589 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 17:36:50 -0700 Subject: [PATCH 176/360] fixed lint issues --- .../bigtable/data/_async/_mutate_rows.py | 2 +- .../cloud/bigtable/data/_async/_read_rows.py | 10 +-- google/cloud/bigtable/data/_async/client.py | 25 ++++---- .../bigtable/data/_async/mutations_batcher.py | 5 +- .../cloud/bigtable/data/_sync/cross_sync.py | 18 +++--- tests/system/data/test_system_async.py | 8 +-- tests/unit/data/_async/test__read_rows.py | 12 ---- tests/unit/data/_async/test_client.py | 64 
++++++------------- .../data/_async/test_mutations_batcher.py | 32 +++------- .../data/_async/test_read_rows_acceptance.py | 9 --- 10 files changed, 60 insertions(+), 125 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index e320a44ea..87f9c25d4 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -34,10 +34,10 @@ from google.cloud.bigtable.data.mutations import RowMutationEntry if CrossSync.is_async: - from google.cloud.bigtable.data._async.client import TableAsync from google.cloud.bigtable_v2.services.bigtable.async_client import ( BigtableAsyncClient, ) + CrossSync.add_mapping("GapicClient", BigtableAsyncClient) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index b78b3160e..7e5c5893e 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -15,10 +15,7 @@ from __future__ import annotations -from typing import ( - TYPE_CHECKING, - Sequence, -) +from typing import Sequence from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB @@ -40,11 +37,6 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -if TYPE_CHECKING: - if CrossSync.is_async: - from google.cloud.bigtable.data._async.client import TableAsync - - @CrossSync.export_sync( path="google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", add_mapping_for_name="_ReadRowsOperation", diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 2d1e3b629..66ec7a646 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -25,7 +25,6 @@ TYPE_CHECKING, ) -import asyncio import time import warnings import random @@ -77,11 +76,6 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync if CrossSync.is_async: - from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( PooledBigtableGrpcAsyncIOTransport, ) @@ -91,10 +85,19 @@ from google.cloud.bigtable_v2.services.bigtable.async_client import ( BigtableAsyncClient, ) + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) + # define file-specific cross-sync replacements CrossSync.add_mapping("GapicClient", BigtableAsyncClient) CrossSync.add_mapping("PooledTransport", PooledBigtableGrpcAsyncIOTransport) CrossSync.add_mapping("PooledChannel", AsyncPooledChannel) + CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) + CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) + CrossSync.add_mapping("MutationsBatcher", MutationsBatcherAsync) if TYPE_CHECKING: @@ -175,9 +178,7 @@ def __init__( client_info=client_info, ) self._is_closed = CrossSync.Event() - self.transport = cast( - CrossSync.PooledTransport, self._gapic_client.transport - ) + self.transport = 
cast(CrossSync.PooledTransport, self._gapic_client.transport) # keep track of active instances to for warmup on channel refresh self._active_instances: Set[_WarmedInstanceKey] = set() # keep track of table objects associated with each instance @@ -479,7 +480,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) -@CrossSync.export_sync(path="google.cloud.bigtable.data._sync.client.Table", add_mapping_for_name="Table") +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync.client.Table", add_mapping_for_name="Table" +) class TableAsync: """ Main Data API surface @@ -1020,7 +1023,7 @@ def mutations_batcher( Returns: MutationsBatcherAsync: a MutationsBatcherAsync context manager that can batch requests """ - return MutationsBatcherAsync( + return CrossSync.MutationsBatcher( self, flush_interval=flush_interval, flush_limit_mutation_count=flush_limit_mutation_count, diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 3a49ef051..d2d77d3a1 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -34,9 +34,6 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -if CrossSync.is_async: - from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync - if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry @@ -46,7 +43,7 @@ @CrossSync.export_sync( path="google.cloud.bigtable.data._sync.mutations_batcher._FlowControl", - add_mapping_for_name="_FlowControl" + add_mapping_for_name="_FlowControl", ) class _FlowControlAsync: """ diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 194eb7524..92b7963c8 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -20,7 +20,6 @@ Callable, Coroutine, Sequence, - Union, AsyncIterable, AsyncIterator, AsyncGenerator, @@ -32,9 +31,6 @@ import sys import concurrent.futures import google.api_core.retry as retries -import time -import threading -import queue if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -80,11 +76,13 @@ def export_sync_impl(*args, **kwargs): register the name as a CrossSync attribute """ new_mapping = kwargs.pop("add_mapping_for_name", None) + def decorator(cls): if new_mapping: # add class to mappings if requested CrossSync.add_mapping(new_mapping, cls) return cls + return decorator @@ -96,7 +94,7 @@ class AstDecorator: but act as no-ops when encountered in live code Args: - attr_name: name of the attribute to attach to the CrossSync class + attr_name: name of the attribute to attach to the CrossSync class e.g. pytest for CrossSync.pytest required_keywords: list of required keyword arguments for the decorator. 
If the decorator is used without these arguments, a ValueError is @@ -251,7 +249,8 @@ def add_mapping(cls, name, value): # list of decorators that can be applied to classes and methods to guide code generation _decorators: list[AstDecorator] = [ - AstDecorator("export_sync", # decorate classes to convert + AstDecorator( + "export_sync", # decorate classes to convert required_keywords=["path"], # otput path for generated sync class async_impl=export_sync_impl, # apply this decorator to the function at runtime replace_symbols={}, # replace specific symbols across entire class @@ -259,11 +258,14 @@ def add_mapping(cls, name, value): include_file_imports=True, # when True, import statements from top of file will be included in output file add_mapping_for_name=None, # add a new attribute to CrossSync class with the given name ), - AstDecorator("convert", # decorate methods to convert from async to sync + AstDecorator( + "convert", # decorate methods to convert from async to sync sync_name=None, # use a new name for the sync class replace_symbols={}, # replace specific symbols within the function ), - AstDecorator("drop_method"), # decorate methods to drop in sync version of class + AstDecorator( + "drop_method" + ), # decorate methods to drop in sync version of class AstDecorator( "pytest", async_impl=pytest_mark_asyncio ), # decorate test methods to run with pytest-asyncio diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 2822c1b6c..d12936305 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -26,11 +26,11 @@ from . import TEST_FAMILY, TEST_FAMILY_2 -if CrossSync.is_async: - from google.cloud.bigtable.data._async.client import BigtableDataClientAsync - -@CrossSync.export_sync(path="tests.system.data.test_system.TempRowBuilder", add_mapping_for_name="TempRowBuilder") +@CrossSync.export_sync( + path="tests.system.data.test_system.TempRowBuilder", + add_mapping_for_name="TempRowBuilder", +) class TempRowBuilderAsync: """ Used to add rows to a table for testing purposes. 
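The test diffs that follow drop per-variant imports and hard-coded patch-path strings in favor of lookups on the shared CrossSync namespace; a hedged sketch of that pattern, assuming the mappings have already been registered (importing the async client module is assumed to perform the registrations as of this point in the series):

from unittest import mock

from google.cloud.bigtable.data._sync.cross_sync import CrossSync
import google.cloud.bigtable.data._async.client  # noqa: F401  assumed to register the CrossSync mappings


def _get_target_class():
    # resolves to the async class here, and to the sync class in the generated tests
    return CrossSync._ReadRowsOperation


# patching the CrossSync attribute covers whichever implementation is active,
# with no "_async"/"_sync" module path string baked into the test body
with mock.patch.object(CrossSync, "_MutateRowsOperation", return_value=CrossSync.Mock()):
    pass  # code under test would pick up the mock via CrossSync._MutateRowsOperation(...)
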
diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 0ec108df9..896c17879 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -15,24 +15,12 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -if CrossSync.is_async: - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync -else: - from google.cloud.bigtable.data._sync._read_rows import ( # noqa: F401 - _ReadRowsOperation, - ) - # try/except added for compatibility with python < 3.8 try: from unittest import mock except ImportError: # pragma: NO COVER import mock # type: ignore -TEST_FAMILY = "family_name" -TEST_QUALIFIER = b"qualifier" -TEST_TIMESTAMP = 123456789 -TEST_LABELS = ["label1", "label2"] - @CrossSync.export_sync( path="tests.unit.data._sync.test__read_rows.TestReadRowsOperation", diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 2aede657e..b51987c5d 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -42,40 +42,9 @@ if CrossSync.is_async: from google.api_core import grpc_helpers_async - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel as PooledChannelAsync, - ) - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - from google.cloud.bigtable.data._async.client import ( - TableAsync, - BigtableDataClientAsync, - ) + from google.cloud.bigtable.data._async.client import TableAsync + CrossSync.add_mapping("grpc_helpers", grpc_helpers_async) -else: - from google.api_core import grpc_helpers # noqa: F401 - from google.cloud.bigtable_v2.services.bigtable.client import ( # noqa: F401 - BigtableClient, - ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( # noqa: F401 - PooledBigtableGrpcTransport, - ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( # noqa: F401 - PooledChannel, - ) - from google.cloud.bigtable.data._sync._read_rows import ( # noqa: F401 - _ReadRowsOperation, - ) - from google.cloud.bigtable.data._sync.client import ( # noqa: F401 - Table, - BigtableDataClient, - ) - CrossSync.add_mapping("grpc_helpers", grpc_helpers) @CrossSync.export_sync( @@ -139,7 +108,9 @@ async def test_ctor_super_inits(self): options_parsed = client_options_lib.from_dict(client_options) asyncio_portion = "-async" if CrossSync.is_async else "" transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" - with mock.patch.object(CrossSync.GapicClient, "__init__") as bigtable_client_init: + with mock.patch.object( + CrossSync.GapicClient, "__init__" + ) as bigtable_client_init: bigtable_client_init.return_value = None with mock.patch.object( ClientWithProject, "__init__" @@ -173,7 +144,9 @@ async def test_ctor_dict_options(self): from google.api_core.client_options import ClientOptions client_options = {"api_endpoint": "foo.bar:1234"} - with mock.patch.object(CrossSync.GapicClient, "__init__") as bigtable_client_init: + with mock.patch.object( + CrossSync.GapicClient, "__init__" + ) as bigtable_client_init: try: self._make_client(client_options=client_options) except TypeError: @@ -1135,7 +1108,9 @@ def 
test_client_ctor_sync(self): assert client._channel_refresh_tasks == [] -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestTable", add_mapping_for_name="TestTable") +@CrossSync.export_sync( + path="tests.unit.data._sync.test_client.TestTable", add_mapping_for_name="TestTable" +) class TestTableAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -1447,7 +1422,10 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ assert "app_profile_id=" not in goog_metadata -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRows", add_mapping_for_name="TestReadRows") +@CrossSync.export_sync( + path="tests.unit.data._sync.test_client.TestReadRows", + add_mapping_for_name="TestReadRows", +) class TestReadRowsAsync: """ Tests for table.read_rows and related methods. @@ -1985,13 +1963,11 @@ async def test_read_rows_sharded_multiple_queries(self): with mock.patch.object( table.client._gapic_client, "read_rows" ) as read_rows: - read_rows.side_effect = ( - lambda *args, **kwargs: CrossSync.TestReadRows._make_gapic_stream( - [ - CrossSync.TestReadRows._make_chunk(row_key=k) - for k in args[0].rows.row_keys - ] - ) + read_rows.side_effect = lambda *args, **kwargs: CrossSync.TestReadRows._make_gapic_stream( + [ + CrossSync.TestReadRows._make_chunk(row_key=k) + for k in args[0].rows.row_keys + ] ) query_1 = ReadRowsQuery(b"test_1") query_2 = ReadRowsQuery(b"test_2") diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 8456f2a38..32121d02b 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -22,20 +22,6 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -if CrossSync.is_async: - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.mutations_batcher import ( - _FlowControlAsync, - MutationsBatcherAsync, - ) -else: - from google.cloud.bigtable.data._sync.client import Table # noqa: F401 - from google.cloud.bigtable.data._sync.mutations_batcher import ( # noqa: F401 - _FlowControl, - MutationsBatcher, - ) - - # try/except added for compatibility with python < 3.8 try: from unittest import mock @@ -43,12 +29,14 @@ import mock # type: ignore -@CrossSync.export_sync(path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl") +@CrossSync.export_sync( + path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl" +) class Test_FlowControl: @staticmethod @CrossSync.convert def _target_class(): - return _FlowControlAsync + return CrossSync._FlowControl def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): return self._target_class()(max_mutation_count, max_mutation_bytes) @@ -967,12 +955,8 @@ async def test__execute_mutate_rows_returns_errors(self): FailedMutationEntryError, ) - if CrossSync.is_async: - mutate_path = "_async.mutations_batcher._MutateRowsOperationAsync" - else: - mutate_path = "_sync.mutations_batcher._MutateRowsOperation" - with mock.patch( - f"google.cloud.bigtable.data.{mutate_path}.start" + with mock.patch.object( + CrossSync._MutateRowsOperation, "start" ) as mutate_rows: err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) @@ -1101,7 +1085,9 @@ async def test_timeout_args_passed(self): batch_operation_timeout and batch_attempt_timeout should be used in api calls """ - with 
mock.patch.object(CrossSync, "_MutateRowsOperation", return_value=CrossSync.Mock()) as mutate_rows: + with mock.patch.object( + CrossSync, "_MutateRowsOperation", return_value=CrossSync.Mock() + ) as mutate_rows: expected_operation_timeout = 17 expected_attempt_timeout = 13 async with self._make_one( diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 95bc95e6d..b30f7544f 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -29,15 +29,6 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -if CrossSync.is_async: - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - from google.cloud.bigtable.data._async.client import BigtableDataClientAsync -else: - from google.cloud.bigtable.data._sync._read_rows import ( # noqa: F401 - _ReadRowsOperation, - ) - from google.cloud.bigtable.data._sync.client import BigtableDataClient # noqa: F401 - @CrossSync.export_sync( path="tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", From de0fd9079abf85ba48ae3771b95031ef079b36f1 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 12 Jul 2024 17:41:22 -0700 Subject: [PATCH 177/360] added missing imports --- google/cloud/bigtable/data/_sync/cross_sync.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index d935dee9e..9a7428378 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -20,6 +20,7 @@ Callable, Coroutine, Sequence, + Union, AsyncIterable, AsyncIterator, AsyncGenerator, @@ -31,6 +32,9 @@ import sys import concurrent.futures import google.api_core.retry as retries +import queue +import threading +import time if TYPE_CHECKING: from typing_extensions import TypeAlias From 6ca3dddb4f3817fa3636743edcce36f6a8086b61 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Sat, 13 Jul 2024 10:04:14 -0700 Subject: [PATCH 178/360] get mappings working for sync implementation --- .cross_sync/transformers.py | 8 +++++++ .../cloud/bigtable/data/_sync/cross_sync.py | 23 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index f3a11c249..1f2e54768 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -227,6 +227,14 @@ def visit_ClassDef(self, node): if sync_cls_name not in output_artifact.contained_classes: converted = self._transform_class(node, sync_cls_name, **kwargs) output_artifact.converted_classes.append(converted) + # add mapping decorator if specified + mapping_name = kwargs.get("add_mapping_for_name") + if mapping_name: + mapping_decorator = ast.Call( + func=ast.Attribute(value=ast.Name(id='CrossSync._Sync_Impl', ctx=ast.Load()), attr='add_mapping_decorator', ctx=ast.Load()), + args=[ast.Str(s=mapping_name)], keywords=[] + ) + converted.decorator_list.append(mapping_decorator) # handle file-level mypy ignores mypy_ignores = [ s diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 9a7428378..488ed5209 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -433,6 +433,29 @@ class _Sync_Impl: Iterator: TypeAlias = typing.Iterator Generator: TypeAlias = typing.Generator + _runtime_replacements: set[Any] = set() + + @classmethod 
+ def add_mapping_decorator(cls, name): + def decorator(wrapped_cls): + cls.add_mapping(name, wrapped_cls) + return wrapped_cls + return decorator + + @classmethod + def add_mapping(cls, name, value): + """ + Add a new attribute to the CrossSync class, for replacing library-level symbols + + Raises: + - AttributeError if the attribute already exists with a different value + """ + if not hasattr(cls, name): + cls._runtime_replacements.add(name) + elif value != getattr(cls, name): + raise AttributeError(f"Conflicting assignments for CrossSync.{name}") + setattr(cls, name, value) + @classmethod def Mock(cls, *args, **kwargs): # try/except added for compatibility with python < 3.8 From b2cf937ee1904c5b84cc080c9409c46b426ff6e7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 15 Jul 2024 11:12:18 -0700 Subject: [PATCH 179/360] moved decorators into new file --- .cross_sync/transformers.py | 100 ++++---- .../cloud/bigtable/data/_sync/cross_sync.py | 173 +------------- .../data/_sync/cross_sync_decorators.py | 218 ++++++++++++++++++ 3 files changed, 275 insertions(+), 216 deletions(-) create mode 100644 google/cloud/bigtable/data/_sync/cross_sync_decorators.py diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 1f2e54768..a76a0d187 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -16,6 +16,7 @@ import ast from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._sync.cross_sync_decorators import AstDecorator, ExportSyncDecorator from generate import CrossSyncOutputFile @@ -142,30 +143,15 @@ def visit_AsyncFunctionDef(self, node): if hasattr(node, "decorator_list"): found_list, node.decorator_list = node.decorator_list, [] for decorator in found_list: - if decorator == CrossSync.convert: - # convert async to sync - kwargs = CrossSync.convert.parse_ast_keywords(decorator) - node = AsyncToSync().visit(node) - # replace method name if specified - if kwargs["sync_name"] is not None: - node.name = kwargs["sync_name"] - # replace symbols if specified - if kwargs["replace_symbols"]: - node = SymbolReplacer(kwargs["replace_symbols"]).visit(node) - elif decorator == CrossSync.drop_method: - # drop method entirely from class - return None - elif decorator == CrossSync.pytest: - # also convert pytest methods to sync - node = AsyncToSync().visit(node) - elif decorator == CrossSync.pytest_fixture: - # add pytest.fixture decorator - decorator.func.value = ast.Name(id="pytest", ctx=ast.Load()) - decorator.func.attr = "fixture" - node.decorator_list.append(decorator) - else: + try: + handler = AstDecorator.get_for_node(decorator) + node = handler.sync_ast_transform(decorator, node, globals()) + if node is None: + return None + except ValueError: # keep unknown decorators node.decorator_list.append(decorator) + continue return node except ValueError as e: raise ValueError(f"node {node.name} failed") from e @@ -213,39 +199,43 @@ def visit_ClassDef(self, node): """ try: for decorator in node.decorator_list: - if decorator == CrossSync.export_sync: - kwargs = CrossSync.export_sync.parse_ast_keywords(decorator) - # find the path to write the sync class to - sync_path = kwargs["path"] - out_file = "/".join(sync_path.rsplit(".")[:-1]) + ".py" - sync_cls_name = sync_path.rsplit(".", 1)[-1] - # find the artifact file for the save location - output_artifact = self._artifact_dict.get( - out_file, CrossSyncOutputFile(out_file) - ) - # write converted class details if not already present - if sync_cls_name not in 
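A minimal usage sketch of the mapping registry added in this patch (illustrative only: "DemoOperation" is a hypothetical name, and it assumes this branch of python-bigtable is importable):

from google.cloud.bigtable.data._sync.cross_sync import CrossSync

@CrossSync._Sync_Impl.add_mapping_decorator("DemoOperation")
class DemoOperation:  # hypothetical placeholder class
    pass

# The name is now resolvable on the sync surface at runtime:
assert CrossSync._Sync_Impl.DemoOperation is DemoOperation
# Re-registering the same value is harmless; registering a different value
# under the same name raises AttributeError, per add_mapping above.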
output_artifact.contained_classes: - converted = self._transform_class(node, sync_cls_name, **kwargs) - output_artifact.converted_classes.append(converted) - # add mapping decorator if specified - mapping_name = kwargs.get("add_mapping_for_name") - if mapping_name: - mapping_decorator = ast.Call( - func=ast.Attribute(value=ast.Name(id='CrossSync._Sync_Impl', ctx=ast.Load()), attr='add_mapping_decorator', ctx=ast.Load()), - args=[ast.Str(s=mapping_name)], keywords=[] - ) - converted.decorator_list.append(mapping_decorator) - # handle file-level mypy ignores - mypy_ignores = [ - s - for s in kwargs["mypy_ignore"] - if s not in output_artifact.mypy_ignore - ] - output_artifact.mypy_ignore.extend(mypy_ignores) - # handle file-level imports - if not output_artifact.imports and kwargs["include_file_imports"]: - output_artifact.imports = self.imports - self._artifact_dict[out_file] = output_artifact + try: + handler = AstDecorator.get_for_node(decorator) + if handler == ExportSyncDecorator: + kwargs = CrossSync.export_sync.parse_ast_keywords(decorator) + # find the path to write the sync class to + sync_path = kwargs["path"] + out_file = "/".join(sync_path.rsplit(".")[:-1]) + ".py" + sync_cls_name = sync_path.rsplit(".", 1)[-1] + # find the artifact file for the save location + output_artifact = self._artifact_dict.get( + out_file, CrossSyncOutputFile(out_file) + ) + # write converted class details if not already present + if sync_cls_name not in output_artifact.contained_classes: + converted = self._transform_class(node, sync_cls_name, **kwargs) + output_artifact.converted_classes.append(converted) + # add mapping decorator if specified + mapping_name = kwargs.get("add_mapping_for_name") + if mapping_name: + mapping_decorator = ast.Call( + func=ast.Attribute(value=ast.Name(id='CrossSync._Sync_Impl', ctx=ast.Load()), attr='add_mapping_decorator', ctx=ast.Load()), + args=[ast.Str(s=mapping_name)], keywords=[] + ) + converted.decorator_list.append(mapping_decorator) + # handle file-level mypy ignores + mypy_ignores = [ + s + for s in kwargs["mypy_ignore"] + if s not in output_artifact.mypy_ignore + ] + output_artifact.mypy_ignore.extend(mypy_ignores) + # handle file-level imports + if not output_artifact.imports and kwargs["include_file_imports"]: + output_artifact.imports = self.imports + self._artifact_dict[out_file] = output_artifact + except ValueError: + continue return node except ValueError as e: raise ValueError(f"failed for class: {node.name}") from e diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 488ed5209..140a4c948 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -35,6 +35,7 @@ import queue import threading import time +from .cross_sync_decorators import AstDecorator, ExportSyncDecorator, ConvertDecorator, DropMethodDecorator, PytestDecorator, PytestFixtureDecorator if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -90,132 +91,7 @@ def decorator(cls): return decorator -class AstDecorator: - """ - Helper class for CrossSync decorators used for guiding ast transformations. - - These decorators provide arguments that are used during the code generation process, - but act as no-ops when encountered in live code - - Args: - attr_name: name of the attribute to attach to the CrossSync class - e.g. pytest for CrossSync.pytest - required_keywords: list of required keyword arguments for the decorator. 
- If the decorator is used without these arguments, a ValueError is - raised during code generation - async_impl: If given, the async code will apply this decorator to its - wrapped function at runtime. If not given, the decorator will be a no-op - **default_kwargs: any kwargs passed define the valid arguments when using the decorator. - The value of each kwarg is the default value for the argument. - """ - - def __init__( - self, - attr_name, - required_keywords=(), - async_impl=None, - **default_kwargs, - ): - self.name = attr_name - self.required_kwargs = required_keywords - self.default_kwargs = default_kwargs - self.all_valid_keys = [*required_keywords, *default_kwargs.keys()] - self.async_impl = async_impl - - def __call__(self, *args, **kwargs): - """ - Called when the decorator is used in code. - - Returns a no-op decorator function, or applies the async_impl decorator - """ - # raise error if invalid kwargs are passed - for kwarg in kwargs: - if kwarg not in self.all_valid_keys: - raise ValueError(f"Invalid keyword argument: {kwarg}") - # if async_impl is provided, use the given decorator function - if self.async_impl: - return self.async_impl(*args, **{**self.default_kwargs, **kwargs}) - # if no arguments, args[0] will hold the function to be decorated - # return the function as is - if len(args) == 1 and callable(args[0]): - return args[0] - - # if arguments are provided, return a no-op decorator function - def decorator(func): - return func - - return decorator - - def parse_ast_keywords(self, node): - """ - When this decorator is encountered in the ast during sync generation, parse the - keyword arguments back from ast nodes to python primitives - - Return a full set of kwargs, using default values for missing arguments - """ - got_kwargs = ( - {kw.arg: self._convert_ast_to_py(kw.value) for kw in node.keywords} - if hasattr(node, "keywords") - else {} - ) - for key in got_kwargs.keys(): - if key not in self.all_valid_keys: - raise ValueError(f"Invalid keyword argument: {key}") - for key in self.required_kwargs: - if key not in got_kwargs: - raise ValueError(f"Missing required keyword argument: {key}") - return {**self.default_kwargs, **got_kwargs} - - def _convert_ast_to_py(self, ast_node): - """ - Helper to convert ast primitives to python primitives. 
Used when unwrapping kwargs - """ - import ast - - if isinstance(ast_node, ast.Constant): - return ast_node.value - if isinstance(ast_node, ast.List): - return [self._convert_ast_to_py(node) for node in ast_node.elts] - if isinstance(ast_node, ast.Dict): - return { - self._convert_ast_to_py(k): self._convert_ast_to_py(v) - for k, v in zip(ast_node.keys, ast_node.values) - } - raise ValueError(f"Unsupported type {type(ast_node)}") - - def _node_eq(self, node): - """ - Check if the given ast node is a call to this decorator - """ - import ast - - if "CrossSync" in ast.dump(node): - decorator_type = node.func.attr if hasattr(node, "func") else node.attr - if decorator_type == self.name: - return True - return False - - def __eq__(self, other): - """ - Helper to support == comparison with ast nodes - """ - return self._node_eq(other) - - -class _DecoratorMeta(type): - """ - Metaclass to attach AstDecorator objects in internal self._decorators - as attributes - """ - - def __getattr__(self, name): - for decorator in self._decorators: - if name == decorator.name: - return decorator - raise AttributeError(f"CrossSync has no attribute {name}") - - -class CrossSync(metaclass=_DecoratorMeta): +class CrossSync: # support CrossSync.is_async to check if the current environment is async is_async = True @@ -237,6 +113,16 @@ class CrossSync(metaclass=_DecoratorMeta): Iterator: TypeAlias = AsyncIterator Generator: TypeAlias = AsyncGenerator + # decorators + export_sync = ExportSyncDecorator() # decorate classes to convert + convert = ConvertDecorator() # decorate methods to convert from async to sync + drop_method = DropMethodDecorator() # decorate methods to remove from sync version + pytest = PytestDecorator() # decorate test methods to run with pytest-asyncio + pytest_fixture = PytestFixtureDecorator() # decorate test methods to run with pytest fixture + + # list of attributes that can be added to the CrossSync class at runtime + _runtime_replacements: set[Any] = set() + @classmethod def add_mapping(cls, name, value): """ @@ -251,41 +137,6 @@ def add_mapping(cls, name, value): raise AttributeError(f"Conflicting assignments for CrossSync.{name}") setattr(cls, name, value) - # list of decorators that can be applied to classes and methods to guide code generation - _decorators: list[AstDecorator] = [ - AstDecorator( - "export_sync", # decorate classes to convert - required_keywords=["path"], # otput path for generated sync class - async_impl=export_sync_impl, # apply this decorator to the function at runtime - replace_symbols={}, # replace specific symbols across entire class - mypy_ignore=(), # set of mypy error codes to ignore in output file - include_file_imports=True, # when True, import statements from top of file will be included in output file - add_mapping_for_name=None, # add a new attribute to CrossSync class with the given name - ), - AstDecorator( - "convert", # decorate methods to convert from async to sync - sync_name=None, # use a new name for the sync class - replace_symbols={}, # replace specific symbols within the function - ), - AstDecorator( - "drop_method" - ), # decorate methods to drop in sync version of class - AstDecorator( - "pytest", async_impl=pytest_mark_asyncio - ), # decorate test methods to run with pytest-asyncio - AstDecorator( - "pytest_fixture", # decorate test methods to run with pytest fixture - async_impl=pytest_asyncio_fixture, - scope="function", - params=None, - autouse=False, - ids=None, - name=None, - ), - ] - # list of attributes that can be added to the 
CrossSync class at runtime - _runtime_replacements: set[Any] = set() - @classmethod def Mock(cls, *args, **kwargs): """ diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py new file mode 100644 index 000000000..5bcbe8c5e --- /dev/null +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -0,0 +1,218 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class AstDecorator: + """ + Helper class for CrossSync decorators used for guiding ast transformations. + + These decorators provide arguments that are used during the code generation process, + but act as no-ops when encountered in live code + + Args: + attr_name: name of the attribute to attach to the CrossSync class + e.g. pytest for CrossSync.pytest + required_keywords: list of required keyword arguments for the decorator. + If the decorator is used without these arguments, a ValueError is + raised during code generation + async_impl: If given, the async code will apply this decorator to its + wrapped function at runtime. If not given, the decorator will be a no-op + **default_kwargs: any kwargs passed define the valid arguments when using the decorator. + The value of each kwarg is the default value for the argument. + """ + + name = None + required_kwargs = () + default_kwargs = {} + + @classmethod + def all_valid_keys(cls): + return [*cls.required_kwargs, *cls.default_kwargs.keys()] + + def __call__(self, *args, **kwargs): + """ + Called when the decorator is used in code. + + Returns a no-op decorator function, or applies the async_impl decorator + """ + # raise error if invalid kwargs are passed + for kwarg in kwargs: + if kwarg not in self.all_valid_keys(): + raise ValueError(f"Invalid keyword argument: {kwarg}") + return self.async_decorator(*args, **kwargs) + + @classmethod + def parse_ast_keywords(cls, node): + """ + When this decorator is encountered in the ast during sync generation, parse the + keyword arguments back from ast nodes to python primitives + + Return a full set of kwargs, using default values for missing arguments + """ + got_kwargs = ( + {kw.arg: cls._convert_ast_to_py(kw.value) for kw in node.keywords} + if hasattr(node, "keywords") + else {} + ) + for key in got_kwargs.keys(): + if key not in cls.all_valid_keys(): + raise ValueError(f"Invalid keyword argument: {key}") + for key in cls.required_kwargs: + if key not in got_kwargs: + raise ValueError(f"Missing required keyword argument: {key}") + return {**cls.default_kwargs, **got_kwargs} + + @classmethod + def _convert_ast_to_py(cls, ast_node): + """ + Helper to convert ast primitives to python primitives. 
Used when unwrapping kwargs + """ + import ast + + if isinstance(ast_node, ast.Constant): + return ast_node.value + if isinstance(ast_node, ast.List): + return [cls._convert_ast_to_py(node) for node in ast_node.elts] + if isinstance(ast_node, ast.Dict): + return { + cls._convert_ast_to_py(k): cls._convert_ast_to_py(v) + for k, v in zip(ast_node.keys, ast_node.values) + } + raise ValueError(f"Unsupported type {type(ast_node)}") + + @classmethod + def async_decorator(cls, *args, **kwargs): + """ + Decorator to apply the async_impl decorator to the wrapped function + + Default implementation is a no-op + """ + # if no arguments, args[0] will hold the function to be decorated + # return the function as is + if len(args) == 1 and callable(args[0]): + return args[0] + + # if arguments are provided, return a no-op decorator function + def decorator(func): + return func + + return decorator + + @classmethod + def sync_ast_transform(cls, decorator, wrapped_node, transformers): + """ + When this decorator is encountered in the ast during sync generation, + apply this behavior + + Defaults to no-op + """ + return wrapped_node + + @classmethod + def get_for_node(cls, node): + import ast + if "CrossSync" in ast.dump(node): + decorator_name = node.func.attr if hasattr(node, "func") else node.attr + for subclass in cls.__subclasses__(): + if subclass.name == decorator_name: + return subclass + raise ValueError(f"Unknown decorator encountered") + + +class ExportSyncDecorator(AstDecorator): + + name = "export_sync" + + required_kwargs = ("path",) + default_kwargs = { + "replace_symbols": {}, # replace symbols in the generated sync class + "mypy_ignore": (), # set of mypy errors to ignore + "include_file_imports": True, # include imports from the file in the generated sync class + "add_mapping_for_name": None, # add a new attribute to CrossSync with the given name + } + + @classmethod + def async_decorator(cls, *args, **kwargs): + from .cross_sync import CrossSync + new_mapping = kwargs.get("add_mapping_for_name") + def decorator(cls): + if new_mapping: + CrossSync.add_mapping(new_mapping, cls) + return cls + return decorator + +class ConvertDecorator(AstDecorator): + + name = "convert" + + default_kwargs = { + "sync_name": None, # use a new name for the sync method + "replace_symbols": {}, # replace symbols in the generated sync method + } + + @classmethod + def sync_ast_transform(cls, decorator, wrapped_node, transformers): + kwargs = cls.parse_ast_keywords(decorator) + if kwargs["sync_name"]: + wrapped_node.name = kwargs["sync_name"] + if kwargs["replace_symbols"]: + replacer = transformers["SymbolReplacer"] + wrapped_node = replacer(kwargs["replace_symbols"]).visit(wrapped_node) + return wrapped_node + + +class DropMethodDecorator(AstDecorator): + + name = "drop_method" + + @classmethod + def sync_ast_transform(cls, decorator, wrapped_node, transformers): + return None + +class PytestDecorator(AstDecorator): + + name = "pytest" + + @classmethod + def async_decorator(cls, *args, **kwargs): + import pytest + return pytest.mark.asyncio + +class PytestFixtureDecorator(AstDecorator): + + name = "pytest_fixture" + + # arguments passed down to pytest(_asyncio).fixture decorator + default_kwargs = { + "scope": "function", + "params": None, + "autouse": False, + "ids": None, + "name": None, + } + + @classmethod + def async_decorator(cls, *args, **kwargs): + import pytest_asyncio + def decorator(func): + return pytest_asyncio.fixture(**kwargs)(func) + return decorator + + @classmethod + def 
sync_ast_transform(cls, decorator, wrapped_node, transformers): + import ast + decorator.func.value = ast.Name(id="pytest", ctx=ast.Load()) + decorator.func.attr = "fixture" + wrapped_node.decorator_list.append(decorator) + return wrapped_node From 4860061d681bcf9243a108d0ecfe538545c3386b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Jul 2024 14:52:55 -0700 Subject: [PATCH 180/360] moved transformation from transformers into decorators --- .cross_sync/transformers.py | 42 +++--------------- .../data/_sync/cross_sync_decorators.py | 44 +++++++++++++++++++ 2 files changed, 51 insertions(+), 35 deletions(-) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index a76a0d187..f4784f248 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -196,6 +196,9 @@ def visit_ClassDef(self, node): """ Called for each class in file. If class has a CrossSync decorator, it will be transformed according to the decorator arguments. Otherwise, no changes will occur + + Uses a set of CrossSyncOutputFile objects to store the transformed classes + and avoid duplicate writes """ try: for decorator in node.decorator_list: @@ -213,16 +216,11 @@ def visit_ClassDef(self, node): ) # write converted class details if not already present if sync_cls_name not in output_artifact.contained_classes: - converted = self._transform_class(node, sync_cls_name, **kwargs) + # transformation is handled in sync_ast_transform method of the decorator + converted = ExportSyncDecorator.sync_ast_transform( + decorator, node, globals() + ) output_artifact.converted_classes.append(converted) - # add mapping decorator if specified - mapping_name = kwargs.get("add_mapping_for_name") - if mapping_name: - mapping_decorator = ast.Call( - func=ast.Attribute(value=ast.Name(id='CrossSync._Sync_Impl', ctx=ast.Load()), attr='add_mapping_decorator', ctx=ast.Load()), - args=[ast.Str(s=mapping_name)], keywords=[] - ) - converted.decorator_list.append(mapping_decorator) # handle file-level mypy ignores mypy_ignores = [ s @@ -240,32 +238,6 @@ def visit_ClassDef(self, node): except ValueError as e: raise ValueError(f"failed for class: {node.name}") from e - def _transform_class( - self, - cls_ast: ast.ClassDef, - new_name: str, - replace_symbols: dict[str, str] | None = None, - **kwargs, - ) -> ast.ClassDef: - """ - Transform async class into sync one, by applying the following ast transformations: - - SymbolReplacer: to replace any class-level symbols specified in CrossSync.export_sync(replace_symbols={}) decorator - - CrossSyncMethodDecoratorHandler: to visit each method in the class and apply any CrossSync decorators found - """ - # update name - cls_ast.name = new_name - # strip CrossSync decorators - if hasattr(cls_ast, "decorator_list"): - cls_ast.decorator_list = [ - d for d in cls_ast.decorator_list if "CrossSync" not in ast.dump(d) - ] - # convert class contents - cls_ast = self.cross_sync_symbol_transformer.visit(cls_ast) - if replace_symbols: - cls_ast = SymbolReplacer(replace_symbols).visit(cls_ast) - cls_ast = self.cross_sync_method_handler.visit(cls_ast) - return cls_ast - def _get_imports( self, tree: ast.Module ) -> list[ast.Import | ast.ImportFrom | ast.Try | ast.If]: diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py index 5bcbe8c5e..51b9530ef 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -152,6 +152,50 @@ def 
decorator(cls): return cls return decorator + @classmethod + def sync_ast_transform(cls, decorator, wrapped_node, transformers): + """ + Transform async class into sync copy + """ + import ast + import copy + kwargs = cls.parse_ast_keywords(decorator) + # copy wrapped node + wrapped_node = copy.deepcopy(wrapped_node) + # update name + sync_path = kwargs["path"] + sync_cls_name = sync_path.rsplit(".", 1)[-1] + orig_name = wrapped_node.name + wrapped_node.name = sync_cls_name + # strip CrossSync decorators + if hasattr(wrapped_node, "decorator_list"): + wrapped_node.decorator_list = [ + d for d in wrapped_node.decorator_list if "CrossSync" not in ast.dump(d) + ] + # add mapping decorator if needed + if kwargs["add_mapping_for_name"]: + wrapped_node.decorator_list.append( + ast.Call( + func=ast.Attribute( + value=ast.Name(id="CrossSync", ctx=ast.Load()), + attr="add_mapping", + ctx=ast.Load(), + ), + args=[ + ast.Constant(value=kwargs["add_mapping_for_name"]), + ], + keywords=[], + ) + ) + # convert class contents + replace_dict = kwargs["replace_symbols"] or {} + replace_dict.update({"CrossSync": f"CrossSync._SyncImpl"}) + wrapped_node = transformers["SymbolReplacer"](replace_dict).visit(wrapped_node) + # visit CrossSync method decorators + wrapped_node = transformers["CrossSyncMethodDecoratorHandler"]().visit(wrapped_node) + return wrapped_node + + class ConvertDecorator(AstDecorator): name = "convert" From 63891d7b4c9d6548b24a8818cde0cb40d16e6aa5 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Jul 2024 15:48:51 -0700 Subject: [PATCH 181/360] simplified arguments for decorators --- .../data/_sync/cross_sync_decorators.py | 170 +++++++----------- 1 file changed, 64 insertions(+), 106 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py index 51b9530ef..4c25cc1c6 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Sequence class AstDecorator: @@ -19,59 +20,21 @@ class AstDecorator: These decorators provide arguments that are used during the code generation process, but act as no-ops when encountered in live code - - Args: - attr_name: name of the attribute to attach to the CrossSync class - e.g. pytest for CrossSync.pytest - required_keywords: list of required keyword arguments for the decorator. - If the decorator is used without these arguments, a ValueError is - raised during code generation - async_impl: If given, the async code will apply this decorator to its - wrapped function at runtime. If not given, the decorator will be a no-op - **default_kwargs: any kwargs passed define the valid arguments when using the decorator. - The value of each kwarg is the default value for the argument. """ - name = None - required_kwargs = () - default_kwargs = {} - - @classmethod - def all_valid_keys(cls): - return [*cls.required_kwargs, *cls.default_kwargs.keys()] - def __call__(self, *args, **kwargs): """ Called when the decorator is used in code. 
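As an aside on the export_sync transform shown above: the ast.Call it appends to the generated class's decorator_list unparses back to an ordinary decorator line. A standalone sketch (Python 3.9+ for ast.unparse; "Demo" is a placeholder mapping name):

import ast

mapping_decorator = ast.Call(
    func=ast.Attribute(
        value=ast.Name(id="CrossSync", ctx=ast.Load()),
        attr="add_mapping",
        ctx=ast.Load(),
    ),
    args=[ast.Constant(value="Demo")],
    keywords=[],
)
print(ast.unparse(mapping_decorator))  # prints: CrossSync.add_mapping('Demo')

So this node renders as CrossSync.add_mapping('Demo') in the generated sync source.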
Returns a no-op decorator function, or applies the async_impl decorator """ - # raise error if invalid kwargs are passed - for kwarg in kwargs: - if kwarg not in self.all_valid_keys(): - raise ValueError(f"Invalid keyword argument: {kwarg}") - return self.async_decorator(*args, **kwargs) - - @classmethod - def parse_ast_keywords(cls, node): - """ - When this decorator is encountered in the ast during sync generation, parse the - keyword arguments back from ast nodes to python primitives - - Return a full set of kwargs, using default values for missing arguments - """ - got_kwargs = ( - {kw.arg: cls._convert_ast_to_py(kw.value) for kw in node.keywords} - if hasattr(node, "keywords") - else {} - ) - for key in got_kwargs.keys(): - if key not in cls.all_valid_keys(): - raise ValueError(f"Invalid keyword argument: {key}") - for key in cls.required_kwargs: - if key not in got_kwargs: - raise ValueError(f"Missing required keyword argument: {key}") - return {**cls.default_kwargs, **got_kwargs} + new_instance = self.__class__(**kwargs) + wrapper = new_instance.async_decorator() + if len(args) == 1 and callable(args[0]): + # if decorator is used without arguments, return wrapped function directly + return wrapper(args[0]) + # otherwise, return wrap function + return wrapper @classmethod def _convert_ast_to_py(cls, ast_node): @@ -91,26 +54,18 @@ def _convert_ast_to_py(cls, ast_node): } raise ValueError(f"Unsupported type {type(ast_node)}") - @classmethod - def async_decorator(cls, *args, **kwargs): + def async_decorator(self): """ Decorator to apply the async_impl decorator to the wrapped function Default implementation is a no-op """ - # if no arguments, args[0] will hold the function to be decorated - # return the function as is - if len(args) == 1 and callable(args[0]): - return args[0] - - # if arguments are provided, return a no-op decorator function - def decorator(func): - return func + def decorator(f): + return f return decorator - @classmethod - def sync_ast_transform(cls, decorator, wrapped_node, transformers): + def sync_ast_transform(self, decorator, wrapped_node, transformers): """ When this decorator is encountered in the ast during sync generation, apply this behavior @@ -124,9 +79,14 @@ def get_for_node(cls, node): import ast if "CrossSync" in ast.dump(node): decorator_name = node.func.attr if hasattr(node, "func") else node.attr + got_kwargs = ( + {kw.arg: cls._convert_ast_to_py(kw.value) for kw in node.keywords} + if hasattr(node, "keywords") + else {} + ) for subclass in cls.__subclasses__(): if subclass.name == decorator_name: - return subclass + return subclass(**got_kwargs) raise ValueError(f"Unknown decorator encountered") @@ -134,37 +94,41 @@ class ExportSyncDecorator(AstDecorator): name = "export_sync" - required_kwargs = ("path",) - default_kwargs = { - "replace_symbols": {}, # replace symbols in the generated sync class - "mypy_ignore": (), # set of mypy errors to ignore - "include_file_imports": True, # include imports from the file in the generated sync class - "add_mapping_for_name": None, # add a new attribute to CrossSync with the given name - } - - @classmethod - def async_decorator(cls, *args, **kwargs): + def __init__( + self, + path:str = "", # path to output the generated sync class + replace_symbols:dict|None = None, # replace symbols in the generated sync class + mypy_ignore:Sequence[str] = (), # set of mypy errors to ignore + include_file_imports:bool = True, # include imports from the file in the generated sync class + add_mapping_for_name:str|None = None, # 
add a new attribute to CrossSync with the given name + ): + self.path = path + self.replace_symbols = replace_symbols + self.mypy_ignore = mypy_ignore + self.include_file_imports = include_file_imports + self.add_mapping_for_name = add_mapping_for_name + + def async_decorator(self): from .cross_sync import CrossSync - new_mapping = kwargs.get("add_mapping_for_name") + new_mapping = self.add_mapping_for_name def decorator(cls): if new_mapping: CrossSync.add_mapping(new_mapping, cls) return cls return decorator - @classmethod - def sync_ast_transform(cls, decorator, wrapped_node, transformers): + def sync_ast_transform(self, decorator, wrapped_node, transformers): """ Transform async class into sync copy """ import ast import copy - kwargs = cls.parse_ast_keywords(decorator) + if not self.path: + raise ValueError(f"{wrapped_node.name} has no path specified in export_sync decorator") # copy wrapped node wrapped_node = copy.deepcopy(wrapped_node) # update name - sync_path = kwargs["path"] - sync_cls_name = sync_path.rsplit(".", 1)[-1] + sync_cls_name = self.path.rsplit(".", 1)[-1] orig_name = wrapped_node.name wrapped_node.name = sync_cls_name # strip CrossSync decorators @@ -173,7 +137,7 @@ def sync_ast_transform(cls, decorator, wrapped_node, transformers): d for d in wrapped_node.decorator_list if "CrossSync" not in ast.dump(d) ] # add mapping decorator if needed - if kwargs["add_mapping_for_name"]: + if self.add_mapping_for_name: wrapped_node.decorator_list.append( ast.Call( func=ast.Attribute( @@ -182,13 +146,13 @@ def sync_ast_transform(cls, decorator, wrapped_node, transformers): ctx=ast.Load(), ), args=[ - ast.Constant(value=kwargs["add_mapping_for_name"]), + ast.Constant(value=self.add_mapping_for_name), ], keywords=[], ) ) # convert class contents - replace_dict = kwargs["replace_symbols"] or {} + replace_dict = self.replace_symbols or {} replace_dict.update({"CrossSync": f"CrossSync._SyncImpl"}) wrapped_node = transformers["SymbolReplacer"](replace_dict).visit(wrapped_node) # visit CrossSync method decorators @@ -200,19 +164,20 @@ class ConvertDecorator(AstDecorator): name = "convert" - default_kwargs = { - "sync_name": None, # use a new name for the sync method - "replace_symbols": {}, # replace symbols in the generated sync method - } - - @classmethod - def sync_ast_transform(cls, decorator, wrapped_node, transformers): - kwargs = cls.parse_ast_keywords(decorator) - if kwargs["sync_name"]: - wrapped_node.name = kwargs["sync_name"] - if kwargs["replace_symbols"]: + def __init__( + self, + sync_name:str|None = None, # use a new name for the sync method + replace_symbols:dict = {} # replace symbols in the generated sync method + ): + self.sync_name = sync_name + self.replace_symbols = replace_symbols + + def sync_ast_transform(self, decorator, wrapped_node, transformers): + if self.sync_name: + wrapped_node.name = self.sync_name + if self.replace_symbols: replacer = transformers["SymbolReplacer"] - wrapped_node = replacer(kwargs["replace_symbols"]).visit(wrapped_node) + wrapped_node = replacer(self.replace_symbols).visit(wrapped_node) return wrapped_node @@ -220,16 +185,14 @@ class DropMethodDecorator(AstDecorator): name = "drop_method" - @classmethod - def sync_ast_transform(cls, decorator, wrapped_node, transformers): + def sync_ast_transform(self, decorator, wrapped_node, transformers): return None class PytestDecorator(AstDecorator): name = "pytest" - @classmethod - def async_decorator(cls, *args, **kwargs): + def async_decorator(self): import pytest return pytest.mark.asyncio @@ 
-237,24 +200,19 @@ class PytestFixtureDecorator(AstDecorator): name = "pytest_fixture" - # arguments passed down to pytest(_asyncio).fixture decorator - default_kwargs = { - "scope": "function", - "params": None, - "autouse": False, - "ids": None, - "name": None, - } + def __init__( + self, + scope:str = "function", # passed to pytest.fixture + ): + self.scope = scope - @classmethod - def async_decorator(cls, *args, **kwargs): + def async_decorator(self): import pytest_asyncio def decorator(func): - return pytest_asyncio.fixture(**kwargs)(func) + return pytest_asyncio.fixture(scope=self.scope)(func) return decorator - @classmethod - def sync_ast_transform(cls, decorator, wrapped_node, transformers): + def sync_ast_transform(self, decorator, wrapped_node, transformers): import ast decorator.func.value = ast.Name(id="pytest", ctx=ast.Load()) decorator.func.attr = "fixture" From 164f5a89175cbf4f2f341096c2c38e9be9558ea3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Jul 2024 16:17:13 -0700 Subject: [PATCH 182/360] use separate decorate method --- .cross_sync/transformers.py | 16 +++--- .../cloud/bigtable/data/_sync/cross_sync.py | 12 ++--- .../data/_sync/cross_sync_decorators.py | 54 +++++++++++-------- 3 files changed, 44 insertions(+), 38 deletions(-) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index f4784f248..9a90d78d0 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -16,7 +16,7 @@ import ast from google.cloud.bigtable.data._sync.cross_sync import CrossSync -from google.cloud.bigtable.data._sync.cross_sync_decorators import AstDecorator, ExportSyncDecorator +from google.cloud.bigtable.data._sync.cross_sync_decorators import AstDecorator, ExportSync from generate import CrossSyncOutputFile @@ -204,12 +204,10 @@ def visit_ClassDef(self, node): for decorator in node.decorator_list: try: handler = AstDecorator.get_for_node(decorator) - if handler == ExportSyncDecorator: - kwargs = CrossSync.export_sync.parse_ast_keywords(decorator) + if isinstance(handler, ExportSync): # find the path to write the sync class to - sync_path = kwargs["path"] - out_file = "/".join(sync_path.rsplit(".")[:-1]) + ".py" - sync_cls_name = sync_path.rsplit(".", 1)[-1] + out_file = "/".join(handler.path.rsplit(".")[:-1]) + ".py" + sync_cls_name = handler.path.rsplit(".", 1)[-1] # find the artifact file for the save location output_artifact = self._artifact_dict.get( out_file, CrossSyncOutputFile(out_file) @@ -217,19 +215,19 @@ def visit_ClassDef(self, node): # write converted class details if not already present if sync_cls_name not in output_artifact.contained_classes: # transformation is handled in sync_ast_transform method of the decorator - converted = ExportSyncDecorator.sync_ast_transform( + converted = handler.sync_ast_transform( decorator, node, globals() ) output_artifact.converted_classes.append(converted) # handle file-level mypy ignores mypy_ignores = [ s - for s in kwargs["mypy_ignore"] + for s in handler.mypy_ignore if s not in output_artifact.mypy_ignore ] output_artifact.mypy_ignore.extend(mypy_ignores) # handle file-level imports - if not output_artifact.imports and kwargs["include_file_imports"]: + if not output_artifact.imports and handler.include_file_imports: output_artifact.imports = self.imports self._artifact_dict[out_file] = output_artifact except ValueError: diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 140a4c948..47a5c7d89 100644 --- 
a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -35,7 +35,7 @@ import queue import threading import time -from .cross_sync_decorators import AstDecorator, ExportSyncDecorator, ConvertDecorator, DropMethodDecorator, PytestDecorator, PytestFixtureDecorator +from .cross_sync_decorators import AstDecorator, ExportSync, Convert, DropMethod, Pytest, PytestFixture if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -114,11 +114,11 @@ class CrossSync: Generator: TypeAlias = AsyncGenerator # decorators - export_sync = ExportSyncDecorator() # decorate classes to convert - convert = ConvertDecorator() # decorate methods to convert from async to sync - drop_method = DropMethodDecorator() # decorate methods to remove from sync version - pytest = PytestDecorator() # decorate test methods to run with pytest-asyncio - pytest_fixture = PytestFixtureDecorator() # decorate test methods to run with pytest fixture + export_sync = ExportSync.decorator # decorate classes to convert + convert = Convert.decorator # decorate methods to convert from async to sync + drop_method = DropMethod.decorator # decorate methods to remove from sync version + pytest = Pytest.decorator # decorate test methods to run with pytest-asyncio + pytest_fixture = PytestFixture.decorator # decorate test methods to run with pytest fixture # list of attributes that can be added to the CrossSync class at runtime _runtime_replacements: set[Any] = set() diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py index 4c25cc1c6..e78cf0b5d 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -22,17 +22,24 @@ class AstDecorator: but act as no-ops when encountered in live code """ - def __call__(self, *args, **kwargs): + @classmethod + def decorator(cls, *args, **kwargs): """ Called when the decorator is used in code. 
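With export_sync, convert and the other decorators now exposed through ExportSync.decorator and friends, the async surface applies them as ordinary decorators. A hypothetical sketch (class and method names are placeholders; it assumes this branch is importable and relies on the decorators being runtime no-ops apart from mapping registration):

import asyncio

from google.cloud.bigtable.data._sync.cross_sync import CrossSync


@CrossSync.export_sync(path="google.cloud.bigtable.data._sync.demo.Demo")
class DemoAsync:
    """Hypothetical async class; a sync copy is generated at the given path."""

    @CrossSync.convert(sync_name="wait_for")
    async def wait_for_async(self, delay: float) -> None:
        await asyncio.sleep(delay)

    @CrossSync.drop_method
    async def async_only_helper(self) -> None:
        """Kept on the async surface, dropped from the generated sync class."""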
Returns a no-op decorator function, or applies the async_impl decorator """ - new_instance = self.__class__(**kwargs) - wrapper = new_instance.async_decorator() + # check for decorators with no arguments + func = None if len(args) == 1 and callable(args[0]): - # if decorator is used without arguments, return wrapped function directly - return wrapper(args[0]) + func = args[0] + args = args[1:] + # create new instance from given arguments + new_instance = cls(*args, **kwargs) + wrapper = new_instance.async_decorator() + # if we can, return single wrapped function + if func: + return wrapper(func) # otherwise, return wrap function return wrapper @@ -90,13 +97,14 @@ def get_for_node(cls, node): raise ValueError(f"Unknown decorator encountered") -class ExportSyncDecorator(AstDecorator): +class ExportSync(AstDecorator): name = "export_sync" def __init__( self, - path:str = "", # path to output the generated sync class + path:str, # path to output the generated sync class + *, replace_symbols:dict|None = None, # replace symbols in the generated sync class mypy_ignore:Sequence[str] = (), # set of mypy errors to ignore include_file_imports:bool = True, # include imports from the file in the generated sync class @@ -160,12 +168,13 @@ def sync_ast_transform(self, decorator, wrapped_node, transformers): return wrapped_node -class ConvertDecorator(AstDecorator): +class Convert(AstDecorator): name = "convert" def __init__( self, + *, sync_name:str|None = None, # use a new name for the sync method replace_symbols:dict = {} # replace symbols in the generated sync method ): @@ -181,14 +190,14 @@ def sync_ast_transform(self, decorator, wrapped_node, transformers): return wrapped_node -class DropMethodDecorator(AstDecorator): +class DropMethod(AstDecorator): name = "drop_method" def sync_ast_transform(self, decorator, wrapped_node, transformers): return None -class PytestDecorator(AstDecorator): +class Pytest(AstDecorator): name = "pytest" @@ -196,25 +205,24 @@ def async_decorator(self): import pytest return pytest.mark.asyncio -class PytestFixtureDecorator(AstDecorator): +class PytestFixture(AstDecorator): name = "pytest_fixture" - def __init__( - self, - scope:str = "function", # passed to pytest.fixture - ): - self.scope = scope + def __init__(self, *args, **kwargs): + self._args = args + self._kwargs = kwargs def async_decorator(self): import pytest_asyncio - def decorator(func): - return pytest_asyncio.fixture(scope=self.scope)(func) - return decorator + return lambda f: pytest_asyncio.fixture(*self._args, **self._kwargs)(f) def sync_ast_transform(self, decorator, wrapped_node, transformers): import ast - decorator.func.value = ast.Name(id="pytest", ctx=ast.Load()) - decorator.func.attr = "fixture" - wrapped_node.decorator_list.append(decorator) - return wrapped_node + import copy + new_decorator = copy.deepcopy(decorator) + new_node = copy.deepcopy(wrapped_node) + new_decorator.func.value = ast.Name(id="pytest", ctx=ast.Load()) + new_decorator.func.attr = "fixture" + new_node.decorator_list.append(new_decorator) + return new_node From 8317973fdbca19c2a1acdc4fab6543b364cfd885 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Jul 2024 17:07:53 -0700 Subject: [PATCH 183/360] comments and clean up --- .cross_sync/transformers.py | 6 +- .../cloud/bigtable/data/_sync/cross_sync.py | 48 ---- .../data/_sync/cross_sync_decorators.py | 218 +++++++++++++----- 3 files changed, 157 insertions(+), 115 deletions(-) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 
9a90d78d0..5ef81441e 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -145,7 +145,7 @@ def visit_AsyncFunctionDef(self, node): for decorator in found_list: try: handler = AstDecorator.get_for_node(decorator) - node = handler.sync_ast_transform(decorator, node, globals()) + node = handler.sync_ast_transform(node, globals()) if node is None: return None except ValueError: @@ -215,9 +215,7 @@ def visit_ClassDef(self, node): # write converted class details if not already present if sync_cls_name not in output_artifact.contained_classes: # transformation is handled in sync_ast_transform method of the decorator - converted = handler.sync_ast_transform( - decorator, node, globals() - ) + converted = handler.sync_ast_transform(node, globals()) output_artifact.converted_classes.append(converted) # handle file-level mypy ignores mypy_ignores = [ diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 47a5c7d89..e991d2757 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -43,54 +43,6 @@ T = TypeVar("T") -def pytest_mark_asyncio(func): - """ - Applies pytest.mark.asyncio to a function if pytest is installed, otherwise - returns the function as is - - Used to support CrossSync.pytest decorator, without requiring pytest to be installed - """ - try: - import pytest - - return pytest.mark.asyncio(func) - except ImportError: - return func - - -def pytest_asyncio_fixture(*args, **kwargs): - """ - Applies pytest.fixture to a function if pytest is installed, otherwise - returns the function as is - - Used to support CrossSync.pytest_fixture decorator, without requiring pytest to be installed - """ - import pytest_asyncio # type: ignore - - def decorator(func): - return pytest_asyncio.fixture(*args, **kwargs)(func) - - return decorator - - -def export_sync_impl(*args, **kwargs): - """ - Decorator implementation for CrossSync.export_sync - - When a called with add_mapping_for_name, CrossSync.add_mapping is called to - register the name as a CrossSync attribute - """ - new_mapping = kwargs.pop("add_mapping_for_name", None) - - def decorator(cls): - if new_mapping: - # add class to mappings if requested - CrossSync.add_mapping(new_mapping, cls) - return cls - - return decorator - - class CrossSync: # support CrossSync.is_async to check if the current environment is async is_async = True diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py index e78cf0b5d..982fbc497 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -11,101 +11,159 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence +from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import ast + from typing import Sequence, Callable, Any class AstDecorator: """ Helper class for CrossSync decorators used for guiding ast transformations. - These decorators provide arguments that are used during the code generation process, - but act as no-ops when encountered in live code + CrossSync decorations are accessed in two ways: + 1. The decorations are used directly as method decorations in the async client, + wrapping existing classes and methods + 2. 
The decorations are read back when processing the AST transformations when + generating sync code. + + This class allows the same decorator to be used in both contexts. + + Typically, AstDecorators act as a no-op in async code, and the arguments simply + provide configuration guidance for the sync code generation. """ @classmethod - def decorator(cls, *args, **kwargs): + def decorator(cls, *args, **kwargs) -> Callable[..., Any]: """ - Called when the decorator is used in code. + Provides a callable that can be used as a decorator function in async code + + AstDecorator.decorate is called by CrossSync when attaching decorators to + the CrossSync class. - Returns a no-op decorator function, or applies the async_impl decorator + This method creates a new instance of the class, using the arguments provided + to the decorator, and defers to the async_decorator method of the instance + to build the wrapper function. + + Arguments: + *args: arguments to the decorator + **kwargs: keyword arguments to the decorator """ - # check for decorators with no arguments + # decorators with no arguments will provide the function to be wrapped + # as the first argument. Pull it out if it exists func = None if len(args) == 1 and callable(args[0]): func = args[0] args = args[1:] - # create new instance from given arguments + # create new AstDecorator instance from given decorator arguments new_instance = cls(*args, **kwargs) + # build wrapper wrapper = new_instance.async_decorator() - # if we can, return single wrapped function - if func: + if wrapper is None: + # if no wrapper, return no-op decorator + return func or (lambda f: f) + elif func: + # if we can, return single wrapped function return wrapper(func) - # otherwise, return wrap function - return wrapper - - @classmethod - def _convert_ast_to_py(cls, ast_node): - """ - Helper to convert ast primitives to python primitives. Used when unwrapping kwargs - """ - import ast - - if isinstance(ast_node, ast.Constant): - return ast_node.value - if isinstance(ast_node, ast.List): - return [cls._convert_ast_to_py(node) for node in ast_node.elts] - if isinstance(ast_node, ast.Dict): - return { - cls._convert_ast_to_py(k): cls._convert_ast_to_py(v) - for k, v in zip(ast_node.keys, ast_node.values) - } - raise ValueError(f"Unsupported type {type(ast_node)}") + else: + # otherwise, return decorator function + return wrapper - def async_decorator(self): + def async_decorator(self) -> Callable[..., Any] | None: """ Decorator to apply the async_impl decorator to the wrapped function Default implementation is a no-op """ - def decorator(f): - return f - - return decorator + return None - def sync_ast_transform(self, decorator, wrapped_node, transformers): + def sync_ast_transform(self, wrapped_node:ast.AST, transformers_globals: dict[str, Any]) -> ast.AST | None: """ - When this decorator is encountered in the ast during sync generation, - apply this behavior + When this decorator is encountered in the ast during sync generation, this method is called + to transform the wrapped node. + + If None is returned, the node will be dropped from the output file. - Defaults to no-op + Args: + wrapped_node: ast node representing the wrapped function or class that is being wrapped + transformers_globals: the set of globals() from the transformers module. 
This is used to access + ast transformer classes that live outside the main codebase + Returns: + transformed ast node, or None if the node should be dropped """ return wrapped_node @classmethod - def get_for_node(cls, node): + def get_for_node(cls, node: ast.Call) -> "AstDecorator": + """ + Build an AstDecorator instance from an ast decorator node + + The right subclass is found by comparing the string representation of the + decorator name to the class name. (Both names are converted to lowercase and + underscores are removed for comparison). If a matching subclass is found, + a new instance is created with the provided arguments. + + Args: + node: ast.Call node representing the decorator + Returns: + AstDecorator instance corresponding to the decorator + Raises: + ValueError: if the decorator cannot be parsed + """ import ast if "CrossSync" in ast.dump(node): decorator_name = node.func.attr if hasattr(node, "func") else node.attr + formatted_name = decorator_name.replace("_", "").lower() got_kwargs = ( {kw.arg: cls._convert_ast_to_py(kw.value) for kw in node.keywords} if hasattr(node, "keywords") else {} ) + got_args = [cls._convert_ast_to_py(arg) for arg in node.args] if hasattr(node, "args") else [] for subclass in cls.__subclasses__(): - if subclass.name == decorator_name: - return subclass(**got_kwargs) - raise ValueError(f"Unknown decorator encountered") + if subclass.__name__.lower() == formatted_name: + return subclass(*got_args, **got_kwargs) + raise ValueError(f"Unknown decorator encountered: {decorator_name}") + raise ValueError("Not a CrossSync decorator") + @classmethod + def _convert_ast_to_py(cls, ast_node: ast.expr) -> Any: + """ + Helper to convert ast primitives to python primitives. Used when unwrapping arguments + """ + import ast + + if isinstance(ast_node, ast.Constant): + return ast_node.value + if isinstance(ast_node, ast.List): + return [cls._convert_ast_to_py(node) for node in ast_node.elts] + if isinstance(ast_node, ast.Dict): + return { + cls._convert_ast_to_py(k): cls._convert_ast_to_py(v) + for k, v in zip(ast_node.keys, ast_node.values) + } + raise ValueError(f"Unsupported type {type(ast_node)}") -class ExportSync(AstDecorator): - name = "export_sync" +class ExportSync(AstDecorator): + """ + Class decorator for marking async classes to be converted to sync classes + + Args: + path: path to output the generated sync class + replace_symbols: a dict of symbols and replacements to use when generating sync class + mypy_ignore: set of mypy errors to ignore in the generated file + include_file_imports: if True, include top-level imports from the file in the generated sync class + add_mapping_for_name: when given, will add a new attribute to CrossSync, so the original class and its sync version can be accessed from CrossSync. 
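To make the lookup in get_for_node concrete, a standalone sketch (hypothetical method name; it assumes this branch and its cross_sync_decorators module are importable):

import ast

from google.cloud.bigtable.data._sync.cross_sync_decorators import AstDecorator, Convert

source = (
    "@CrossSync.convert(sync_name='read_rows')\n"
    "async def read_rows_async(self): ...\n"
)
decorator_node = ast.parse(source).body[0].decorator_list[0]
handler = AstDecorator.get_for_node(decorator_node)

assert isinstance(handler, Convert)      # matched by lower-cased class name
assert handler.sync_name == "read_rows"  # keyword parsed back out of the ast node

A decorator that does not mention CrossSync, or that names no known subclass, raises ValueError, which is what lets the transformers keep unknown decorators untouched.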
+ """ def __init__( self, path:str, # path to output the generated sync class *, - replace_symbols:dict|None = None, # replace symbols in the generated sync class + replace_symbols:dict[str,str]|None = None, # replace symbols in the generated sync class mypy_ignore:Sequence[str] = (), # set of mypy errors to ignore include_file_imports:bool = True, # include imports from the file in the generated sync class add_mapping_for_name:str|None = None, # add a new attribute to CrossSync with the given name @@ -117,6 +175,9 @@ def __init__( self.add_mapping_for_name = add_mapping_for_name def async_decorator(self): + """ + Use async decorator as a hook to update CrossSync mappings + """ from .cross_sync import CrossSync new_mapping = self.add_mapping_for_name def decorator(cls): @@ -125,7 +186,7 @@ def decorator(cls): return cls return decorator - def sync_ast_transform(self, decorator, wrapped_node, transformers): + def sync_ast_transform(self, wrapped_node, transformers_globals): """ Transform async class into sync copy """ @@ -162,52 +223,72 @@ def sync_ast_transform(self, decorator, wrapped_node, transformers): # convert class contents replace_dict = self.replace_symbols or {} replace_dict.update({"CrossSync": f"CrossSync._SyncImpl"}) - wrapped_node = transformers["SymbolReplacer"](replace_dict).visit(wrapped_node) + wrapped_node = transformers_globals["SymbolReplacer"](replace_dict).visit(wrapped_node) # visit CrossSync method decorators - wrapped_node = transformers["CrossSyncMethodDecoratorHandler"]().visit(wrapped_node) + wrapped_node = transformers_globals["CrossSyncMethodDecoratorHandler"]().visit(wrapped_node) return wrapped_node class Convert(AstDecorator): + """ + Method decorator to mark async methods to be converted to sync methods - name = "convert" + Args: + sync_name: use a new name for the sync method + replace_symbols: a dict of symbols and replacements to use when generating sync method + """ def __init__( self, *, sync_name:str|None = None, # use a new name for the sync method - replace_symbols:dict = {} # replace symbols in the generated sync method + replace_symbols:dict[str,str] = {} # replace symbols in the generated sync method ): self.sync_name = sync_name self.replace_symbols = replace_symbols - def sync_ast_transform(self, decorator, wrapped_node, transformers): + def sync_ast_transform(self, wrapped_node, transformers_globals): + """ + Transform async method into sync + """ if self.sync_name: wrapped_node.name = self.sync_name if self.replace_symbols: - replacer = transformers["SymbolReplacer"] + replacer = transformers_globals["SymbolReplacer"] wrapped_node = replacer(self.replace_symbols).visit(wrapped_node) return wrapped_node class DropMethod(AstDecorator): + """ + Method decorator to drop async methods from the sync output + """ - name = "drop_method" - - def sync_ast_transform(self, decorator, wrapped_node, transformers): + def sync_ast_transform(self, wrapped_node, transformers_globals): + """ + Drop method from sync output + """ return None class Pytest(AstDecorator): + """ + Used in place of pytest.mark.asyncio to mark tests - name = "pytest" + Will be stripped from sync output + """ def async_decorator(self): import pytest return pytest.mark.asyncio class PytestFixture(AstDecorator): + """ + Used in place of pytest.fixture or pytest.mark.asyncio to mark fixtures - name = "pytest_fixture" + Args: + *args: all arguments to pass to pytest.fixture + **kwargs: all keyword arguments to pass to pytest.fixture + """ def __init__(self, *args, **kwargs): self._args = 
args @@ -217,12 +298,23 @@ def async_decorator(self): import pytest_asyncio return lambda f: pytest_asyncio.fixture(*self._args, **self._kwargs)(f) - def sync_ast_transform(self, decorator, wrapped_node, transformers): + def sync_ast_transform(self, wrapped_node, transformers_globals): import ast import copy - new_decorator = copy.deepcopy(decorator) new_node = copy.deepcopy(wrapped_node) - new_decorator.func.value = ast.Name(id="pytest", ctx=ast.Load()) - new_decorator.func.attr = "fixture" - new_node.decorator_list.append(new_decorator) + new_node.decorator_list.append( + ast.Call( + func=ast.Attribute( + value=ast.Name(id="pytest", ctx=ast.Load()), + attr="fixture", + ctx=ast.Load(), + ), + args=[ + ast.Constant(value=a) for a in self._args + ], + keywords=[ + ast.keyword(arg=k, value=ast.Constant(value=v)) for k, v in self._kwargs.items() + ] + ) + ) return new_node From f9dd41dbb4b6a4d787d9f7a42577ec9843701a2b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Jul 2024 17:08:56 -0700 Subject: [PATCH 184/360] ran blacken --- .../cloud/bigtable/data/_sync/cross_sync.py | 17 ++++- .../data/_sync/cross_sync_decorators.py | 63 +++++++++++++------ .../data/_async/test_mutations_batcher.py | 4 +- 3 files changed, 59 insertions(+), 25 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index e991d2757..6dcbd22fb 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -35,7 +35,14 @@ import queue import threading import time -from .cross_sync_decorators import AstDecorator, ExportSync, Convert, DropMethod, Pytest, PytestFixture +from .cross_sync_decorators import ( + AstDecorator, + ExportSync, + Convert, + DropMethod, + Pytest, + PytestFixture, +) if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -66,11 +73,13 @@ class CrossSync: Generator: TypeAlias = AsyncGenerator # decorators - export_sync = ExportSync.decorator # decorate classes to convert + export_sync = ExportSync.decorator # decorate classes to convert convert = Convert.decorator # decorate methods to convert from async to sync drop_method = DropMethod.decorator # decorate methods to remove from sync version pytest = Pytest.decorator # decorate test methods to run with pytest-asyncio - pytest_fixture = PytestFixture.decorator # decorate test methods to run with pytest fixture + pytest_fixture = ( + PytestFixture.decorator + ) # decorate test methods to run with pytest fixture # list of attributes that can be added to the CrossSync class at runtime _runtime_replacements: set[Any] = set() @@ -217,6 +226,7 @@ class _Sync_Impl: """ Provide sync versions of the async functions and types in CrossSync """ + is_async = False sleep = time.sleep @@ -243,6 +253,7 @@ def add_mapping_decorator(cls, name): def decorator(wrapped_cls): cls.add_mapping(name, wrapped_cls) return wrapped_cls + return decorator @classmethod diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py index 982fbc497..86b733b98 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -24,7 +24,7 @@ class AstDecorator: Helper class for CrossSync decorators used for guiding ast transformations. CrossSync decorations are accessed in two ways: - 1. The decorations are used directly as method decorations in the async client, + 1. 
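A short sketch of how the pytest_fixture pieces above behave on each surface (hypothetical fixture; assumes pytest-asyncio is installed alongside this branch):

from google.cloud.bigtable.data._sync.cross_sync import CrossSync


@CrossSync.pytest_fixture(scope="function")
async def demo_resource():
    # On the async side this resolves to pytest_asyncio.fixture(scope="function").
    yield object()  # stand-in for the real async resource

In the generated sync test module, sync_ast_transform instead appends a plain pytest.fixture(scope="function") decorator to the converted fixture function.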
The decorations are used directly as method decorations in the async client, wrapping existing classes and methods 2. The decorations are read back when processing the AST transformations when generating sync code. @@ -79,7 +79,9 @@ def async_decorator(self) -> Callable[..., Any] | None: """ return None - def sync_ast_transform(self, wrapped_node:ast.AST, transformers_globals: dict[str, Any]) -> ast.AST | None: + def sync_ast_transform( + self, wrapped_node: ast.AST, transformers_globals: dict[str, Any] + ) -> ast.AST | None: """ When this decorator is encountered in the ast during sync generation, this method is called to transform the wrapped node. @@ -102,7 +104,7 @@ def get_for_node(cls, node: ast.Call) -> "AstDecorator": The right subclass is found by comparing the string representation of the decorator name to the class name. (Both names are converted to lowercase and - underscores are removed for comparison). If a matching subclass is found, + underscores are removed for comparison). If a matching subclass is found, a new instance is created with the provided arguments. Args: @@ -113,6 +115,7 @@ def get_for_node(cls, node: ast.Call) -> "AstDecorator": ValueError: if the decorator cannot be parsed """ import ast + if "CrossSync" in ast.dump(node): decorator_name = node.func.attr if hasattr(node, "func") else node.attr formatted_name = decorator_name.replace("_", "").lower() @@ -121,7 +124,11 @@ def get_for_node(cls, node: ast.Call) -> "AstDecorator": if hasattr(node, "keywords") else {} ) - got_args = [cls._convert_ast_to_py(arg) for arg in node.args] if hasattr(node, "args") else [] + got_args = ( + [cls._convert_ast_to_py(arg) for arg in node.args] + if hasattr(node, "args") + else [] + ) for subclass in cls.__subclasses__(): if subclass.__name__.lower() == formatted_name: return subclass(*got_args, **got_kwargs) @@ -161,12 +168,14 @@ class ExportSync(AstDecorator): def __init__( self, - path:str, # path to output the generated sync class + path: str, # path to output the generated sync class *, - replace_symbols:dict[str,str]|None = None, # replace symbols in the generated sync class - mypy_ignore:Sequence[str] = (), # set of mypy errors to ignore - include_file_imports:bool = True, # include imports from the file in the generated sync class - add_mapping_for_name:str|None = None, # add a new attribute to CrossSync with the given name + replace_symbols: dict[str, str] + | None = None, # replace symbols in the generated sync class + mypy_ignore: Sequence[str] = (), # set of mypy errors to ignore + include_file_imports: bool = True, # include imports from the file in the generated sync class + add_mapping_for_name: str + | None = None, # add a new attribute to CrossSync with the given name ): self.path = path self.replace_symbols = replace_symbols @@ -179,11 +188,14 @@ def async_decorator(self): Use async decorator as a hook to update CrossSync mappings """ from .cross_sync import CrossSync + new_mapping = self.add_mapping_for_name + def decorator(cls): if new_mapping: CrossSync.add_mapping(new_mapping, cls) return cls + return decorator def sync_ast_transform(self, wrapped_node, transformers_globals): @@ -192,8 +204,11 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): """ import ast import copy + if not self.path: - raise ValueError(f"{wrapped_node.name} has no path specified in export_sync decorator") + raise ValueError( + f"{wrapped_node.name} has no path specified in export_sync decorator" + ) # copy wrapped node wrapped_node = copy.deepcopy(wrapped_node) # 
update name @@ -223,9 +238,13 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): # convert class contents replace_dict = self.replace_symbols or {} replace_dict.update({"CrossSync": f"CrossSync._SyncImpl"}) - wrapped_node = transformers_globals["SymbolReplacer"](replace_dict).visit(wrapped_node) + wrapped_node = transformers_globals["SymbolReplacer"](replace_dict).visit( + wrapped_node + ) # visit CrossSync method decorators - wrapped_node = transformers_globals["CrossSyncMethodDecoratorHandler"]().visit(wrapped_node) + wrapped_node = transformers_globals["CrossSyncMethodDecoratorHandler"]().visit( + wrapped_node + ) return wrapped_node @@ -241,8 +260,10 @@ class Convert(AstDecorator): def __init__( self, *, - sync_name:str|None = None, # use a new name for the sync method - replace_symbols:dict[str,str] = {} # replace symbols in the generated sync method + sync_name: str | None = None, # use a new name for the sync method + replace_symbols: dict[ + str, str + ] = {}, # replace symbols in the generated sync method ): self.sync_name = sync_name self.replace_symbols = replace_symbols @@ -270,6 +291,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): """ return None + class Pytest(AstDecorator): """ Used in place of pytest.mark.asyncio to mark tests @@ -279,8 +301,10 @@ class Pytest(AstDecorator): def async_decorator(self): import pytest + return pytest.mark.asyncio + class PytestFixture(AstDecorator): """ Used in place of pytest.fixture or pytest.mark.asyncio to mark fixtures @@ -296,11 +320,13 @@ def __init__(self, *args, **kwargs): def async_decorator(self): import pytest_asyncio + return lambda f: pytest_asyncio.fixture(*self._args, **self._kwargs)(f) def sync_ast_transform(self, wrapped_node, transformers_globals): import ast import copy + new_node = copy.deepcopy(wrapped_node) new_node.decorator_list.append( ast.Call( @@ -309,12 +335,11 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): attr="fixture", ctx=ast.Load(), ), - args=[ - ast.Constant(value=a) for a in self._args - ], + args=[ast.Constant(value=a) for a in self._args], keywords=[ - ast.keyword(arg=k, value=ast.Constant(value=v)) for k, v in self._kwargs.items() - ] + ast.keyword(arg=k, value=ast.Constant(value=v)) + for k, v in self._kwargs.items() + ], ) ) return new_node diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 32121d02b..fcd425273 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -955,9 +955,7 @@ async def test__execute_mutate_rows_returns_errors(self): FailedMutationEntryError, ) - with mock.patch.object( - CrossSync._MutateRowsOperation, "start" - ) as mutate_rows: + with mock.patch.object(CrossSync._MutateRowsOperation, "start") as mutate_rows: err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) From 22b093f468af6ff5e0de7fb2010ced5b23706e15 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Jul 2024 17:14:51 -0700 Subject: [PATCH 185/360] moved ast decorators into new file --- .../cloud/bigtable/data/_sync/cross_sync.py | 225 +----------- .../data/_sync/cross_sync_decorators.py | 320 ++++++++++++++++++ 2 files changed, 336 insertions(+), 209 deletions(-) create mode 100644 google/cloud/bigtable/data/_sync/cross_sync_decorators.py diff --git 
a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 92b7963c8..03f3904af 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -20,6 +20,7 @@ Callable, Coroutine, Sequence, + Union, AsyncIterable, AsyncIterator, AsyncGenerator, @@ -31,6 +32,10 @@ import sys import concurrent.futures import google.api_core.retry as retries +import queue +import threading +import time +from .cross_sync_decorators import AstDecorator, ExportSync, Convert, DropMethod, Pytest, PytestFixture if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -38,180 +43,7 @@ T = TypeVar("T") -def pytest_mark_asyncio(func): - """ - Applies pytest.mark.asyncio to a function if pytest is installed, otherwise - returns the function as is - - Used to support CrossSync.pytest decorator, without requiring pytest to be installed - """ - try: - import pytest - - return pytest.mark.asyncio(func) - except ImportError: - return func - - -def pytest_asyncio_fixture(*args, **kwargs): - """ - Applies pytest.fixture to a function if pytest is installed, otherwise - returns the function as is - - Used to support CrossSync.pytest_fixture decorator, without requiring pytest to be installed - """ - import pytest_asyncio # type: ignore - - def decorator(func): - return pytest_asyncio.fixture(*args, **kwargs)(func) - - return decorator - - -def export_sync_impl(*args, **kwargs): - """ - Decorator implementation for CrossSync.export_sync - - When a called with add_mapping_for_name, CrossSync.add_mapping is called to - register the name as a CrossSync attribute - """ - new_mapping = kwargs.pop("add_mapping_for_name", None) - - def decorator(cls): - if new_mapping: - # add class to mappings if requested - CrossSync.add_mapping(new_mapping, cls) - return cls - - return decorator - - -class AstDecorator: - """ - Helper class for CrossSync decorators used for guiding ast transformations. - - These decorators provide arguments that are used during the code generation process, - but act as no-ops when encountered in live code - - Args: - attr_name: name of the attribute to attach to the CrossSync class - e.g. pytest for CrossSync.pytest - required_keywords: list of required keyword arguments for the decorator. - If the decorator is used without these arguments, a ValueError is - raised during code generation - async_impl: If given, the async code will apply this decorator to its - wrapped function at runtime. If not given, the decorator will be a no-op - **default_kwargs: any kwargs passed define the valid arguments when using the decorator. - The value of each kwarg is the default value for the argument. - """ - - def __init__( - self, - attr_name, - required_keywords=(), - async_impl=None, - **default_kwargs, - ): - self.name = attr_name - self.required_kwargs = required_keywords - self.default_kwargs = default_kwargs - self.all_valid_keys = [*required_keywords, *default_kwargs.keys()] - self.async_impl = async_impl - - def __call__(self, *args, **kwargs): - """ - Called when the decorator is used in code. 
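The pytest_mark_asyncio and pytest_asyncio_fixture helpers removed in this hunk are superseded by the Pytest and PytestFixture decorator classes. The runtime behaviour is unchanged; a minimal sketch of how an async test opts in, assuming pytest and pytest-asyncio are installed (the test class and method names are illustrative):

from google.cloud.bigtable.data._sync.cross_sync import CrossSync

class TestMutationsBatcher:
    @CrossSync.pytest  # applies pytest.mark.asyncio at runtime
    async def test_append_triggers_flush(self):
        ...

In the generated sync tests the marker carries no configuration and, as the Pytest class docstring notes, is stripped from the output.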
- - Returns a no-op decorator function, or applies the async_impl decorator - """ - # raise error if invalid kwargs are passed - for kwarg in kwargs: - if kwarg not in self.all_valid_keys: - raise ValueError(f"Invalid keyword argument: {kwarg}") - # if async_impl is provided, use the given decorator function - if self.async_impl: - return self.async_impl(*args, **{**self.default_kwargs, **kwargs}) - # if no arguments, args[0] will hold the function to be decorated - # return the function as is - if len(args) == 1 and callable(args[0]): - return args[0] - - # if arguments are provided, return a no-op decorator function - def decorator(func): - return func - - return decorator - - def parse_ast_keywords(self, node): - """ - When this decorator is encountered in the ast during sync generation, parse the - keyword arguments back from ast nodes to python primitives - - Return a full set of kwargs, using default values for missing arguments - """ - got_kwargs = ( - {kw.arg: self._convert_ast_to_py(kw.value) for kw in node.keywords} - if hasattr(node, "keywords") - else {} - ) - for key in got_kwargs.keys(): - if key not in self.all_valid_keys: - raise ValueError(f"Invalid keyword argument: {key}") - for key in self.required_kwargs: - if key not in got_kwargs: - raise ValueError(f"Missing required keyword argument: {key}") - return {**self.default_kwargs, **got_kwargs} - - def _convert_ast_to_py(self, ast_node): - """ - Helper to convert ast primitives to python primitives. Used when unwrapping kwargs - """ - import ast - - if isinstance(ast_node, ast.Constant): - return ast_node.value - if isinstance(ast_node, ast.List): - return [self._convert_ast_to_py(node) for node in ast_node.elts] - if isinstance(ast_node, ast.Dict): - return { - self._convert_ast_to_py(k): self._convert_ast_to_py(v) - for k, v in zip(ast_node.keys, ast_node.values) - } - raise ValueError(f"Unsupported type {type(ast_node)}") - - def _node_eq(self, node): - """ - Check if the given ast node is a call to this decorator - """ - import ast - - if "CrossSync" in ast.dump(node): - decorator_type = node.func.attr if hasattr(node, "func") else node.attr - if decorator_type == self.name: - return True - return False - - def __eq__(self, other): - """ - Helper to support == comparison with ast nodes - """ - return self._node_eq(other) - - -class _DecoratorMeta(type): - """ - Metaclass to attach AstDecorator objects in internal self._decorators - as attributes - """ - - def __getattr__(self, name): - for decorator in self._decorators: - if name == decorator.name: - return decorator - raise AttributeError(f"CrossSync has no attribute {name}") - - -class CrossSync(metaclass=_DecoratorMeta): +class CrossSync: # support CrossSync.is_async to check if the current environment is async is_async = True @@ -233,6 +65,16 @@ class CrossSync(metaclass=_DecoratorMeta): Iterator: TypeAlias = AsyncIterator Generator: TypeAlias = AsyncGenerator + # decorators + export_sync = ExportSync.decorator # decorate classes to convert + convert = Convert.decorator # decorate methods to convert from async to sync + drop_method = DropMethod.decorator # decorate methods to remove from sync version + pytest = Pytest.decorator # decorate test methods to run with pytest-asyncio + pytest_fixture = PytestFixture.decorator # decorate test methods to run with pytest fixture + + # list of attributes that can be added to the CrossSync class at runtime + _runtime_replacements: set[Any] = set() + @classmethod def add_mapping(cls, name, value): """ @@ -247,41 +89,6 @@ def 
add_mapping(cls, name, value): raise AttributeError(f"Conflicting assignments for CrossSync.{name}") setattr(cls, name, value) - # list of decorators that can be applied to classes and methods to guide code generation - _decorators: list[AstDecorator] = [ - AstDecorator( - "export_sync", # decorate classes to convert - required_keywords=["path"], # otput path for generated sync class - async_impl=export_sync_impl, # apply this decorator to the function at runtime - replace_symbols={}, # replace specific symbols across entire class - mypy_ignore=(), # set of mypy error codes to ignore in output file - include_file_imports=True, # when True, import statements from top of file will be included in output file - add_mapping_for_name=None, # add a new attribute to CrossSync class with the given name - ), - AstDecorator( - "convert", # decorate methods to convert from async to sync - sync_name=None, # use a new name for the sync class - replace_symbols={}, # replace specific symbols within the function - ), - AstDecorator( - "drop_method" - ), # decorate methods to drop in sync version of class - AstDecorator( - "pytest", async_impl=pytest_mark_asyncio - ), # decorate test methods to run with pytest-asyncio - AstDecorator( - "pytest_fixture", # decorate test methods to run with pytest fixture - async_impl=pytest_asyncio_fixture, - scope="function", - params=None, - autouse=False, - ids=None, - name=None, - ), - ] - # list of attributes that can be added to the CrossSync class at runtime - _runtime_replacements: set[Any] = set() - @classmethod def Mock(cls, *args, **kwargs): """ diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py new file mode 100644 index 000000000..982fbc497 --- /dev/null +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -0,0 +1,320 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import ast + from typing import Sequence, Callable, Any + + +class AstDecorator: + """ + Helper class for CrossSync decorators used for guiding ast transformations. + + CrossSync decorations are accessed in two ways: + 1. The decorations are used directly as method decorations in the async client, + wrapping existing classes and methods + 2. The decorations are read back when processing the AST transformations when + generating sync code. + + This class allows the same decorator to be used in both contexts. + + Typically, AstDecorators act as a no-op in async code, and the arguments simply + provide configuration guidance for the sync code generation. + """ + + @classmethod + def decorator(cls, *args, **kwargs) -> Callable[..., Any]: + """ + Provides a callable that can be used as a decorator function in async code + + AstDecorator.decorate is called by CrossSync when attaching decorators to + the CrossSync class. 
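The classmethod has to accept two runtime call forms: the bare form, where the wrapped function arrives as the only positional argument, and the called form, where the arguments configure a new decorator instance. A minimal sketch of both, with class and method names illustrative:

from google.cloud.bigtable.data._sync.cross_sync import CrossSync

class TableAsync:
    @CrossSync.convert(sync_name="read_rows")   # called form: kwargs configure Convert(...)
    async def read_rows_async(self, query):
        ...

    @CrossSync.drop_method                      # bare form: the method itself is passed as args[0]
    async def _async_only_helper(self):
        ...

Neither Convert nor DropMethod defines an async_decorator, so both usages are no-ops at runtime and only guide the sync code generation.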
+ + This method creates a new instance of the class, using the arguments provided + to the decorator, and defers to the async_decorator method of the instance + to build the wrapper function. + + Arguments: + *args: arguments to the decorator + **kwargs: keyword arguments to the decorator + """ + # decorators with no arguments will provide the function to be wrapped + # as the first argument. Pull it out if it exists + func = None + if len(args) == 1 and callable(args[0]): + func = args[0] + args = args[1:] + # create new AstDecorator instance from given decorator arguments + new_instance = cls(*args, **kwargs) + # build wrapper + wrapper = new_instance.async_decorator() + if wrapper is None: + # if no wrapper, return no-op decorator + return func or (lambda f: f) + elif func: + # if we can, return single wrapped function + return wrapper(func) + else: + # otherwise, return decorator function + return wrapper + + def async_decorator(self) -> Callable[..., Any] | None: + """ + Decorator to apply the async_impl decorator to the wrapped function + + Default implementation is a no-op + """ + return None + + def sync_ast_transform(self, wrapped_node:ast.AST, transformers_globals: dict[str, Any]) -> ast.AST | None: + """ + When this decorator is encountered in the ast during sync generation, this method is called + to transform the wrapped node. + + If None is returned, the node will be dropped from the output file. + + Args: + wrapped_node: ast node representing the wrapped function or class that is being wrapped + transformers_globals: the set of globals() from the transformers module. This is used to access + ast transformer classes that live outside the main codebase + Returns: + transformed ast node, or None if the node should be dropped + """ + return wrapped_node + + @classmethod + def get_for_node(cls, node: ast.Call) -> "AstDecorator": + """ + Build an AstDecorator instance from an ast decorator node + + The right subclass is found by comparing the string representation of the + decorator name to the class name. (Both names are converted to lowercase and + underscores are removed for comparison). If a matching subclass is found, + a new instance is created with the provided arguments. + + Args: + node: ast.Call node representing the decorator + Returns: + AstDecorator instance corresponding to the decorator + Raises: + ValueError: if the decorator cannot be parsed + """ + import ast + if "CrossSync" in ast.dump(node): + decorator_name = node.func.attr if hasattr(node, "func") else node.attr + formatted_name = decorator_name.replace("_", "").lower() + got_kwargs = ( + {kw.arg: cls._convert_ast_to_py(kw.value) for kw in node.keywords} + if hasattr(node, "keywords") + else {} + ) + got_args = [cls._convert_ast_to_py(arg) for arg in node.args] if hasattr(node, "args") else [] + for subclass in cls.__subclasses__(): + if subclass.__name__.lower() == formatted_name: + return subclass(*got_args, **got_kwargs) + raise ValueError(f"Unknown decorator encountered: {decorator_name}") + raise ValueError("Not a CrossSync decorator") + + @classmethod + def _convert_ast_to_py(cls, ast_node: ast.expr) -> Any: + """ + Helper to convert ast primitives to python primitives. 
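On the generation side, get_for_node reverses this: it reads a decorator node out of the parsed async source and rebuilds the matching subclass from the normalized name plus the literal arguments. A small runnable sketch, assuming the module path created by this patch:

import ast
from google.cloud.bigtable.data._sync.cross_sync_decorators import AstDecorator, Convert

src = """\
@CrossSync.convert(sync_name="read_rows")
async def read_rows_async(self):
    ...
"""
fn = ast.parse(src).body[0]
handler = AstDecorator.get_for_node(fn.decorator_list[0])
assert isinstance(handler, Convert)      # "convert" matched Convert by lowercased name
assert handler.sync_name == "read_rows"  # keyword literal recovered via _convert_ast_to_py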
Used when unwrapping arguments + """ + import ast + + if isinstance(ast_node, ast.Constant): + return ast_node.value + if isinstance(ast_node, ast.List): + return [cls._convert_ast_to_py(node) for node in ast_node.elts] + if isinstance(ast_node, ast.Dict): + return { + cls._convert_ast_to_py(k): cls._convert_ast_to_py(v) + for k, v in zip(ast_node.keys, ast_node.values) + } + raise ValueError(f"Unsupported type {type(ast_node)}") + + +class ExportSync(AstDecorator): + """ + Class decorator for marking async classes to be converted to sync classes + + Args: + path: path to output the generated sync class + replace_symbols: a dict of symbols and replacements to use when generating sync class + mypy_ignore: set of mypy errors to ignore in the generated file + include_file_imports: if True, include top-level imports from the file in the generated sync class + add_mapping_for_name: when given, will add a new attribute to CrossSync, so the original class and its sync version can be accessed from CrossSync. + """ + + def __init__( + self, + path:str, # path to output the generated sync class + *, + replace_symbols:dict[str,str]|None = None, # replace symbols in the generated sync class + mypy_ignore:Sequence[str] = (), # set of mypy errors to ignore + include_file_imports:bool = True, # include imports from the file in the generated sync class + add_mapping_for_name:str|None = None, # add a new attribute to CrossSync with the given name + ): + self.path = path + self.replace_symbols = replace_symbols + self.mypy_ignore = mypy_ignore + self.include_file_imports = include_file_imports + self.add_mapping_for_name = add_mapping_for_name + + def async_decorator(self): + """ + Use async decorator as a hook to update CrossSync mappings + """ + from .cross_sync import CrossSync + new_mapping = self.add_mapping_for_name + def decorator(cls): + if new_mapping: + CrossSync.add_mapping(new_mapping, cls) + return cls + return decorator + + def sync_ast_transform(self, wrapped_node, transformers_globals): + """ + Transform async class into sync copy + """ + import ast + import copy + if not self.path: + raise ValueError(f"{wrapped_node.name} has no path specified in export_sync decorator") + # copy wrapped node + wrapped_node = copy.deepcopy(wrapped_node) + # update name + sync_cls_name = self.path.rsplit(".", 1)[-1] + orig_name = wrapped_node.name + wrapped_node.name = sync_cls_name + # strip CrossSync decorators + if hasattr(wrapped_node, "decorator_list"): + wrapped_node.decorator_list = [ + d for d in wrapped_node.decorator_list if "CrossSync" not in ast.dump(d) + ] + # add mapping decorator if needed + if self.add_mapping_for_name: + wrapped_node.decorator_list.append( + ast.Call( + func=ast.Attribute( + value=ast.Name(id="CrossSync", ctx=ast.Load()), + attr="add_mapping", + ctx=ast.Load(), + ), + args=[ + ast.Constant(value=self.add_mapping_for_name), + ], + keywords=[], + ) + ) + # convert class contents + replace_dict = self.replace_symbols or {} + replace_dict.update({"CrossSync": f"CrossSync._SyncImpl"}) + wrapped_node = transformers_globals["SymbolReplacer"](replace_dict).visit(wrapped_node) + # visit CrossSync method decorators + wrapped_node = transformers_globals["CrossSyncMethodDecoratorHandler"]().visit(wrapped_node) + return wrapped_node + + +class Convert(AstDecorator): + """ + Method decorator to mark async methods to be converted to sync methods + + Args: + sync_name: use a new name for the sync method + replace_symbols: a dict of symbols and replacements to use when generating sync method 
+ """ + + def __init__( + self, + *, + sync_name:str|None = None, # use a new name for the sync method + replace_symbols:dict[str,str] = {} # replace symbols in the generated sync method + ): + self.sync_name = sync_name + self.replace_symbols = replace_symbols + + def sync_ast_transform(self, wrapped_node, transformers_globals): + """ + Transform async method into sync + """ + if self.sync_name: + wrapped_node.name = self.sync_name + if self.replace_symbols: + replacer = transformers_globals["SymbolReplacer"] + wrapped_node = replacer(self.replace_symbols).visit(wrapped_node) + return wrapped_node + + +class DropMethod(AstDecorator): + """ + Method decorator to drop async methods from the sync output + """ + + def sync_ast_transform(self, wrapped_node, transformers_globals): + """ + Drop method from sync output + """ + return None + +class Pytest(AstDecorator): + """ + Used in place of pytest.mark.asyncio to mark tests + + Will be stripped from sync output + """ + + def async_decorator(self): + import pytest + return pytest.mark.asyncio + +class PytestFixture(AstDecorator): + """ + Used in place of pytest.fixture or pytest.mark.asyncio to mark fixtures + + Args: + *args: all arguments to pass to pytest.fixture + **kwargs: all keyword arguments to pass to pytest.fixture + """ + + def __init__(self, *args, **kwargs): + self._args = args + self._kwargs = kwargs + + def async_decorator(self): + import pytest_asyncio + return lambda f: pytest_asyncio.fixture(*self._args, **self._kwargs)(f) + + def sync_ast_transform(self, wrapped_node, transformers_globals): + import ast + import copy + new_node = copy.deepcopy(wrapped_node) + new_node.decorator_list.append( + ast.Call( + func=ast.Attribute( + value=ast.Name(id="pytest", ctx=ast.Load()), + attr="fixture", + ctx=ast.Load(), + ), + args=[ + ast.Constant(value=a) for a in self._args + ], + keywords=[ + ast.keyword(arg=k, value=ast.Constant(value=v)) for k, v in self._kwargs.items() + ] + ) + ) + return new_node From 8d13c5e8f37964f31edca9a5ac6e7c7d2666b241 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Jul 2024 17:18:41 -0700 Subject: [PATCH 186/360] ran blacken --- .../cloud/bigtable/data/_sync/cross_sync.py | 15 ++++- .../data/_sync/cross_sync_decorators.py | 63 +++++++++++++------ .../data/_async/test_mutations_batcher.py | 4 +- 3 files changed, 57 insertions(+), 25 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 03f3904af..862286e3d 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -35,7 +35,14 @@ import queue import threading import time -from .cross_sync_decorators import AstDecorator, ExportSync, Convert, DropMethod, Pytest, PytestFixture +from .cross_sync_decorators import ( + AstDecorator, + ExportSync, + Convert, + DropMethod, + Pytest, + PytestFixture, +) if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -66,11 +73,13 @@ class CrossSync: Generator: TypeAlias = AsyncGenerator # decorators - export_sync = ExportSync.decorator # decorate classes to convert + export_sync = ExportSync.decorator # decorate classes to convert convert = Convert.decorator # decorate methods to convert from async to sync drop_method = DropMethod.decorator # decorate methods to remove from sync version pytest = Pytest.decorator # decorate test methods to run with pytest-asyncio - pytest_fixture = PytestFixture.decorator # decorate test methods to run with pytest fixture + pytest_fixture = ( + 
PytestFixture.decorator + ) # decorate test methods to run with pytest fixture # list of attributes that can be added to the CrossSync class at runtime _runtime_replacements: set[Any] = set() diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py index 982fbc497..86b733b98 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -24,7 +24,7 @@ class AstDecorator: Helper class for CrossSync decorators used for guiding ast transformations. CrossSync decorations are accessed in two ways: - 1. The decorations are used directly as method decorations in the async client, + 1. The decorations are used directly as method decorations in the async client, wrapping existing classes and methods 2. The decorations are read back when processing the AST transformations when generating sync code. @@ -79,7 +79,9 @@ def async_decorator(self) -> Callable[..., Any] | None: """ return None - def sync_ast_transform(self, wrapped_node:ast.AST, transformers_globals: dict[str, Any]) -> ast.AST | None: + def sync_ast_transform( + self, wrapped_node: ast.AST, transformers_globals: dict[str, Any] + ) -> ast.AST | None: """ When this decorator is encountered in the ast during sync generation, this method is called to transform the wrapped node. @@ -102,7 +104,7 @@ def get_for_node(cls, node: ast.Call) -> "AstDecorator": The right subclass is found by comparing the string representation of the decorator name to the class name. (Both names are converted to lowercase and - underscores are removed for comparison). If a matching subclass is found, + underscores are removed for comparison). If a matching subclass is found, a new instance is created with the provided arguments. 
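Typical fixture usage ties the two halves together: in async test modules the decorator defers to pytest_asyncio.fixture, while the generated sync tests get a plain pytest.fixture call appended with the same arguments. A hedged sketch, with the client construction and fixture body being illustrative:

from google.cloud.bigtable.data import BigtableDataClientAsync
from google.cloud.bigtable.data._sync.cross_sync import CrossSync

class TestTableAsync:
    @CrossSync.pytest_fixture(scope="function")
    async def client(self):
        async with BigtableDataClientAsync(project="my-project") as c:
            yield c

The keyword arguments are carried over verbatim: PytestFixture.sync_ast_transform re-emits them as ast.keyword nodes on the generated pytest.fixture decorator.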
Args: @@ -113,6 +115,7 @@ def get_for_node(cls, node: ast.Call) -> "AstDecorator": ValueError: if the decorator cannot be parsed """ import ast + if "CrossSync" in ast.dump(node): decorator_name = node.func.attr if hasattr(node, "func") else node.attr formatted_name = decorator_name.replace("_", "").lower() @@ -121,7 +124,11 @@ def get_for_node(cls, node: ast.Call) -> "AstDecorator": if hasattr(node, "keywords") else {} ) - got_args = [cls._convert_ast_to_py(arg) for arg in node.args] if hasattr(node, "args") else [] + got_args = ( + [cls._convert_ast_to_py(arg) for arg in node.args] + if hasattr(node, "args") + else [] + ) for subclass in cls.__subclasses__(): if subclass.__name__.lower() == formatted_name: return subclass(*got_args, **got_kwargs) @@ -161,12 +168,14 @@ class ExportSync(AstDecorator): def __init__( self, - path:str, # path to output the generated sync class + path: str, # path to output the generated sync class *, - replace_symbols:dict[str,str]|None = None, # replace symbols in the generated sync class - mypy_ignore:Sequence[str] = (), # set of mypy errors to ignore - include_file_imports:bool = True, # include imports from the file in the generated sync class - add_mapping_for_name:str|None = None, # add a new attribute to CrossSync with the given name + replace_symbols: dict[str, str] + | None = None, # replace symbols in the generated sync class + mypy_ignore: Sequence[str] = (), # set of mypy errors to ignore + include_file_imports: bool = True, # include imports from the file in the generated sync class + add_mapping_for_name: str + | None = None, # add a new attribute to CrossSync with the given name ): self.path = path self.replace_symbols = replace_symbols @@ -179,11 +188,14 @@ def async_decorator(self): Use async decorator as a hook to update CrossSync mappings """ from .cross_sync import CrossSync + new_mapping = self.add_mapping_for_name + def decorator(cls): if new_mapping: CrossSync.add_mapping(new_mapping, cls) return cls + return decorator def sync_ast_transform(self, wrapped_node, transformers_globals): @@ -192,8 +204,11 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): """ import ast import copy + if not self.path: - raise ValueError(f"{wrapped_node.name} has no path specified in export_sync decorator") + raise ValueError( + f"{wrapped_node.name} has no path specified in export_sync decorator" + ) # copy wrapped node wrapped_node = copy.deepcopy(wrapped_node) # update name @@ -223,9 +238,13 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): # convert class contents replace_dict = self.replace_symbols or {} replace_dict.update({"CrossSync": f"CrossSync._SyncImpl"}) - wrapped_node = transformers_globals["SymbolReplacer"](replace_dict).visit(wrapped_node) + wrapped_node = transformers_globals["SymbolReplacer"](replace_dict).visit( + wrapped_node + ) # visit CrossSync method decorators - wrapped_node = transformers_globals["CrossSyncMethodDecoratorHandler"]().visit(wrapped_node) + wrapped_node = transformers_globals["CrossSyncMethodDecoratorHandler"]().visit( + wrapped_node + ) return wrapped_node @@ -241,8 +260,10 @@ class Convert(AstDecorator): def __init__( self, *, - sync_name:str|None = None, # use a new name for the sync method - replace_symbols:dict[str,str] = {} # replace symbols in the generated sync method + sync_name: str | None = None, # use a new name for the sync method + replace_symbols: dict[ + str, str + ] = {}, # replace symbols in the generated sync method ): self.sync_name = sync_name 
self.replace_symbols = replace_symbols @@ -270,6 +291,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): """ return None + class Pytest(AstDecorator): """ Used in place of pytest.mark.asyncio to mark tests @@ -279,8 +301,10 @@ class Pytest(AstDecorator): def async_decorator(self): import pytest + return pytest.mark.asyncio + class PytestFixture(AstDecorator): """ Used in place of pytest.fixture or pytest.mark.asyncio to mark fixtures @@ -296,11 +320,13 @@ def __init__(self, *args, **kwargs): def async_decorator(self): import pytest_asyncio + return lambda f: pytest_asyncio.fixture(*self._args, **self._kwargs)(f) def sync_ast_transform(self, wrapped_node, transformers_globals): import ast import copy + new_node = copy.deepcopy(wrapped_node) new_node.decorator_list.append( ast.Call( @@ -309,12 +335,11 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): attr="fixture", ctx=ast.Load(), ), - args=[ - ast.Constant(value=a) for a in self._args - ], + args=[ast.Constant(value=a) for a in self._args], keywords=[ - ast.keyword(arg=k, value=ast.Constant(value=v)) for k, v in self._kwargs.items() - ] + ast.keyword(arg=k, value=ast.Constant(value=v)) + for k, v in self._kwargs.items() + ], ) ) return new_node diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 32121d02b..fcd425273 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -955,9 +955,7 @@ async def test__execute_mutate_rows_returns_errors(self): FailedMutationEntryError, ) - with mock.patch.object( - CrossSync._MutateRowsOperation, "start" - ) as mutate_rows: + with mock.patch.object(CrossSync._MutateRowsOperation, "start") as mutate_rows: err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) From 769cac1dec3aeeb74ef64a105a293a9d070f70c3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Jul 2024 17:30:26 -0700 Subject: [PATCH 187/360] changed sync impl name --- google/cloud/bigtable/data/_sync/cross_sync_decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py index 86b733b98..a0c991f1a 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -237,7 +237,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): ) # convert class contents replace_dict = self.replace_symbols or {} - replace_dict.update({"CrossSync": f"CrossSync._SyncImpl"}) + replace_dict.update({"CrossSync": f"CrossSync._Sync_Impl"}) wrapped_node = transformers_globals["SymbolReplacer"](replace_dict).visit( wrapped_node ) From f87b832e70c2805edbd8f008dc232eaf321fc430 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Jul 2024 17:30:36 -0700 Subject: [PATCH 188/360] fixed mapping function --- google/cloud/bigtable/data/_sync/cross_sync_decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py index a0c991f1a..0d212ae72 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -226,7 
+226,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): ast.Call( func=ast.Attribute( value=ast.Name(id="CrossSync", ctx=ast.Load()), - attr="add_mapping", + attr="add_mapping_decorator", ctx=ast.Load(), ), args=[ From f2b6d086fdf69e0de4fff8c9848b9e2f3df25eff Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Jul 2024 17:30:46 -0700 Subject: [PATCH 189/360] convert to async with pytest mark --- google/cloud/bigtable/data/_sync/cross_sync_decorators.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py index 0d212ae72..10e6f7c5b 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -304,6 +304,14 @@ def async_decorator(self): return pytest.mark.asyncio + def sync_ast_transform(self, wrapped_node, transformers_globals): + """ + convert async to sync + """ + import ast + converted = transformers_globals["AsyncToSync"]().visit(wrapped_node) + return converted + class PytestFixture(AstDecorator): """ From 3ed5935cf3428de77744ac36605e83732a1033cb Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 17 Jul 2024 16:40:54 -0700 Subject: [PATCH 190/360] convert changes async to sync def by default --- .../cloud/bigtable/data/_sync/cross_sync_decorators.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py index 10e6f7c5b..3ce8a8d16 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -272,8 +272,17 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): """ Transform async method into sync """ + import ast + # replace async function with sync function + wrapped_node = ast.copy_location( + ast.FunctionDef(wrapped_node.name, wrapped_node.args, + wrapped_node.body, wrapped_node.decorator_list, wrapped_node.returns, + ), wrapped_node, + ) + # update name if specified if self.sync_name: wrapped_node.name = self.sync_name + # update arbitrary symbols if specified if self.replace_symbols: replacer = transformers_globals["SymbolReplacer"] wrapped_node = replacer(self.replace_symbols).visit(wrapped_node) From d0ba7b0b763b7b3ec50c11c19d0a20dcbb850a16 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 17 Jul 2024 17:28:30 -0700 Subject: [PATCH 191/360] added rm_aio for stripping asyncio keywords --- .cross_sync/transformers.py | 56 +++++ .../bigtable/data/_async/_mutate_rows.py | 14 +- .../cloud/bigtable/data/_async/_read_rows.py | 8 +- google/cloud/bigtable/data/_async/client.py | 206 +++++++++++------- .../bigtable/data/_async/mutations_batcher.py | 36 +-- .../cloud/bigtable/data/_sync/cross_sync.py | 4 + .../data/_sync/cross_sync_decorators.py | 14 +- 7 files changed, 228 insertions(+), 110 deletions(-) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 5ef81441e..6b48421d3 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -129,6 +129,62 @@ def visit_ListComp(self, node): generator.is_async = False return self.generic_visit(node) +class RmAioFunctions(ast.NodeTransformer): + """ + Visits all calls marked with CrossSync.rm_aio, and removes asyncio keywords + """ + + def __init__(self): + self.to_sync = AsyncToSync() + + def visit_Call(self, node): + if isinstance(node.func, ast.Attribute) and 
isinstance(node.func.value, ast.Name) and \ + node.func.attr == "rm_aio" and "CrossSync" in node.func.value.id: + return self.visit(self.to_sync.visit(node.args[0])) + return self.generic_visit(node) + + def visit_AsyncWith(self, node): + """ + Async with statements are not fully wrapped by calls + """ + found_rmaio = False + new_items = [] + for item in node.items: + if isinstance(item.context_expr, ast.Call) and isinstance(item.context_expr.func, ast.Attribute) and isinstance(item.context_expr.func.value, ast.Name) and \ + item.context_expr.func.attr == "rm_aio" and "CrossSync" in item.context_expr.func.value.id: + found_rmaio = True + new_items.append(item.context_expr.args[0]) + else: + new_items.append(item) + if found_rmaio: + new_node = ast.copy_location( + ast.With( + [self.generic_visit(item) for item in new_items], + [self.generic_visit(stmt) for stmt in node.body], + ), + node, + ) + return self.generic_visit(new_node) + return self.generic_visit(node) + + def visit_AsyncFor(self, node): + """ + Async for statements are not fully wrapped by calls + """ + it = node.iter + if isinstance(it, ast.Call) and isinstance(it.func, ast.Attribute) and isinstance(it.func.value, ast.Name) and \ + it.func.attr == "rm_aio" and "CrossSync" in it.func.value.id: + return ast.copy_location( + ast.For( + self.visit(node.target), + self.visit(node.iter.args[0]), + [self.visit(stmt) for stmt in node.body], + [self.visit(stmt) for stmt in node.orelse], + ), + node, + ) + return self.generic_visit(node) + class CrossSyncMethodDecoratorHandler(ast.NodeTransformer): """ diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 87f9c25d4..e62d43397 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -124,7 +124,7 @@ async def start(self): """ try: # trigger mutate_rows - await self._operation() + CrossSync.rm_aio(await self._operation()) except Exception as exc: # exceptions raised by retryable are added to the list of exceptions for all unfinalized mutations incomplete_indices = self.remaining_indices.copy() @@ -172,12 +172,14 @@ async def _run_attempt(self): return # make gapic request try: - result_generator = await self._gapic_fn( - timeout=next(self.timeout_generator), - entries=request_entries, - retry=None, + result_generator = CrossSync.rm_aio( + await self._gapic_fn( + timeout=next(self.timeout_generator), + entries=request_entries, + retry=None, + ) ) - async for result_list in result_generator: + async for result_list in CrossSync.rm_aio(result_generator): for result in result_list.entries: # convert sub-request index to global index orig_idx = active_request_indices[result.index] diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 7e5c5893e..2fe48e9e9 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -164,7 +164,7 @@ async def chunk_stream( Yields: ReadRowsResponsePB.CellChunk: the next chunk in the stream """ - async for resp in await stream: + async for resp in CrossSync.rm_aio(await stream): # extract proto from proto-plus wrapper resp = resp._pb @@ -225,7 +225,7 @@ async def merge_rows( # For each row while True: try: - c = await it.__anext__() + c = CrossSync.rm_aio(await it.__anext__()) except CrossSync.StopIteration: # stream complete return @@ -274,7 +274,7 @@ async def merge_rows( buffer = [value] while c.value_size > 0: # 
throws when premature end - c = await it.__anext__() + c = CrossSync.rm_aio(await it.__anext__()) t = c.timestamp_micros cl = c.labels @@ -306,7 +306,7 @@ async def merge_rows( if c.commit_row: yield Row(row_key, cells) break - c = await it.__anext__() + c = CrossSync.rm_aio(await it.__anext__()) except _ResetRow as e: c = e.chunk if ( diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 66ec7a646..f18b46256 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -257,10 +257,12 @@ async def close(self, timeout: float | None = 2.0): self._is_closed.set() for task in self._channel_refresh_tasks: task.cancel() - await self.transport.close() + CrossSync.rm_aio(await self.transport.close()) if self._executor: self._executor.shutdown(wait=False) - await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) + CrossSync.rm_aio( + await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) + ) self._channel_refresh_tasks = [] @CrossSync.convert @@ -300,8 +302,10 @@ async def _ping_and_warm_instances( ) for (instance_name, table_name, app_profile_id) in instance_list ] - result_list = await CrossSync.gather_partials( - partial_list, return_exceptions=True, sync_executor=self._executor + result_list = CrossSync.rm_aio( + await CrossSync.gather_partials( + partial_list, return_exceptions=True, sync_executor=self._executor + ) ) return [r or None for r in result_list] @@ -339,27 +343,31 @@ async def _manage_channel( if next_sleep > 0: # warm the current channel immediately channel = self.transport.channels[channel_idx] - await self._ping_and_warm_instances(channel) + CrossSync.rm_aio(await self._ping_and_warm_instances(channel)) # continuously refresh the channel every `refresh_interval` seconds while not self._is_closed.is_set(): - await CrossSync.event_wait( - self._is_closed, - next_sleep, - async_break_early=False, # no need to interrupt sleep. Task will be cancelled on close + CrossSync.rm_aio( + await CrossSync.event_wait( + self._is_closed, + next_sleep, + async_break_early=False, # no need to interrupt sleep. Task will be cancelled on close + ) ) if self._is_closed.is_set(): # don't refresh if client is closed break # prepare new channel for use new_channel = self.transport.grpc_channel._create_channel() - await self._ping_and_warm_instances(new_channel) + CrossSync.rm_aio(await self._ping_and_warm_instances(new_channel)) # cycle channel out of use, with long grace window before closure start_timestamp = time.monotonic() - await self.transport.replace_channel( - channel_idx, - grace=grace_period, - new_channel=new_channel, - event=self._is_closed, + CrossSync.rm_aio( + await self.transport.replace_channel( + channel_idx, + grace=grace_period, + new_channel=new_channel, + event=self._is_closed, + ) ) # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) @@ -391,7 +399,9 @@ async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: # refresh tasks already running # call ping and warm on all existing channels for channel in self.transport.channels: - await self._ping_and_warm_instances(channel, instance_key) + CrossSync.rm_aio( + await self._ping_and_warm_instances(channel, instance_key) + ) else: # refresh tasks aren't active. 
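The CrossSync.rm_aio wrapper that now appears throughout these client methods is an identity function at runtime; it only marks expressions whose asyncio keywords the RmAioFunctions transformer should strip when the sync surface is generated. A short sketch of both sides (the generated form is approximate):

from google.cloud.bigtable.data._sync.cross_sync import CrossSync

assert CrossSync.rm_aio(42) == 42  # no-op at runtime

# async source, as written above:
#     CrossSync.rm_aio(await self._ping_and_warm_instances(channel))
# generated sync code, after the marker and the await are removed:
#     self._ping_and_warm_instances(channel)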
start them as background tasks self._start_background_channel_refresh() @@ -476,8 +486,8 @@ async def __aenter__(self): @CrossSync.convert(sync_name="__exit__", replace_symbols={"__aexit__": "__exit__"}) async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) + CrossSync.rm_aio(await self.close()) + CrossSync.rm_aio(await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb)) @CrossSync.export_sync( @@ -705,13 +715,15 @@ async def read_rows( from any retries that failed google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error """ - row_generator = await self.read_rows_stream( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, + row_generator = CrossSync.rm_aio( + await self.read_rows_stream( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) ) - return [row async for row in row_generator] + return CrossSync.rm_aio([row async for row in row_generator]) @CrossSync.convert async def read_row( @@ -753,11 +765,13 @@ async def read_row( if row_key is None: raise ValueError("row_key must be string or bytes") query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1) - results = await self.read_rows( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, + results = CrossSync.rm_aio( + await self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) ) if len(results) == 0: return None @@ -817,25 +831,31 @@ async def read_rows_sharded( concurrency_sem = CrossSync.Semaphore(_CONCURRENCY_LIMIT) async def read_rows_with_semaphore(query): - async with concurrency_sem: + async with CrossSync.rm_aio(concurrency_sem): # calculate new timeout based on time left in overall operation shard_timeout = next(rpc_timeout_generator) if shard_timeout <= 0: raise DeadlineExceeded( "Operation timeout exceeded before starting query" ) - return await self.read_rows( - query, - operation_timeout=shard_timeout, - attempt_timeout=min(attempt_timeout, shard_timeout), - retryable_errors=retryable_errors, + return CrossSync.rm_aio( + await self.read_rows( + query, + operation_timeout=shard_timeout, + attempt_timeout=min(attempt_timeout, shard_timeout), + retryable_errors=retryable_errors, + ) ) routine_list = [ partial(read_rows_with_semaphore, query) for query in sharded_query ] - batch_result = await CrossSync.gather_partials( - routine_list, return_exceptions=True, sync_executor=self.client._executor + batch_result = CrossSync.rm_aio( + await CrossSync.gather_partials( + routine_list, + return_exceptions=True, + sync_executor=self.client._executor, + ) ) # collect results and errors @@ -904,11 +924,13 @@ async def row_exists( limit_filter = CellsRowLimitFilter(1) chain_filter = RowFilterChain(filters=[limit_filter, strip_filter]) query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter) - results = await self.read_rows( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, + results = CrossSync.rm_aio( + await self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) ) return len(results) > 0 @@ -968,21 +990,27 @@ async def sample_row_keys( 
metadata = _make_metadata(self.table_name, self.app_profile_id) async def execute_rpc(): - results = await self.client._gapic_client.sample_row_keys( - table_name=self.table_name, - app_profile_id=self.app_profile_id, - timeout=next(attempt_timeout_gen), - metadata=metadata, - retry=None, + results = CrossSync.rm_aio( + await self.client._gapic_client.sample_row_keys( + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=next(attempt_timeout_gen), + metadata=metadata, + retry=None, + ) + ) + return CrossSync.rm_aio( + [(s.row_key, s.offset_bytes) async for s in results] ) - return [(s.row_key, s.offset_bytes) async for s in results] - return await CrossSync.retry_target( - execute_rpc, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, + return CrossSync.rm_aio( + await CrossSync.retry_target( + execute_rpc, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) ) @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) @@ -1106,12 +1134,14 @@ async def mutate_row( metadata=_make_metadata(self.table_name, self.app_profile_id), retry=None, ) - return await CrossSync.retry_target( - target, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, + return CrossSync.rm_aio( + await CrossSync.retry_target( + target, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) ) @CrossSync.convert @@ -1168,7 +1198,7 @@ async def bulk_mutate_rows( attempt_timeout, retryable_exceptions=retryable_excs, ) - await operation.start() + CrossSync.rm_aio(await operation.start()) @CrossSync.convert async def check_and_mutate_row( @@ -1224,16 +1254,20 @@ async def check_and_mutate_row( false_case_mutations = [false_case_mutations] false_case_list = [m._to_pb() for m in false_case_mutations or []] metadata = _make_metadata(self.table_name, self.app_profile_id) - result = await self.client._gapic_client.check_and_mutate_row( - true_mutations=true_case_list, - false_mutations=false_case_list, - predicate_filter=predicate._to_pb() if predicate is not None else None, - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, + result = CrossSync.rm_aio( + await self.client._gapic_client.check_and_mutate_row( + true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) ) return result.predicate_matched @@ -1276,14 +1310,18 @@ async def read_modify_write_row( if not rules: raise ValueError("rules must contain at least one item") metadata = _make_metadata(self.table_name, self.app_profile_id) - result = await self.client._gapic_client.read_modify_write_row( - rules=[rule._to_pb() for rule in rules], - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, + result = CrossSync.rm_aio( + await self.client._gapic_client.read_modify_write_row( + rules=[rule._to_pb() for rule in rules], + 
row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) ) # construct Row from result return Row._from_pb(result.row) @@ -1295,7 +1333,9 @@ async def close(self): """ if self._register_instance_future: self._register_instance_future.cancel() - await self.client._remove_instance_registration(self.instance_id, self) + CrossSync.rm_aio( + await self.client._remove_instance_registration(self.instance_id, self) + ) @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): @@ -1306,7 +1346,7 @@ async def __aenter__(self): grpc channels will be warmed for the specified instance """ if self._register_instance_future: - await self._register_instance_future + CrossSync.rm_aio(await self._register_instance_future) return self @CrossSync.convert(sync_name="__exit__") @@ -1317,4 +1357,4 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): Unregister this instance with the client, so that grpc channels will no longer be warmed """ - await self.close() + CrossSync.rm_aio(await self.close()) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index d2d77d3a1..7a6def9e4 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -120,7 +120,7 @@ async def remove_from_flow( self._in_flight_mutation_count -= total_count self._in_flight_mutation_bytes -= total_size # notify any blocked requests that there is additional capacity - async with self._capacity_condition: + async with CrossSync.rm_aio(self._capacity_condition): self._capacity_condition.notify_all() @CrossSync.convert @@ -146,7 +146,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] start_idx = end_idx batch_mutation_count = 0 # fill up batch until we hit capacity - async with self._capacity_condition: + async with CrossSync.rm_aio(self._capacity_condition): while end_idx < len(mutations): next_entry = mutations[end_idx] next_size = next_entry.size() @@ -167,8 +167,10 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] break else: # batch is empty. 
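The same marker also wraps async with and async for headers, which visit_AsyncWith and visit_AsyncFor in the transformer rewrite into plain with and for statements. Roughly, for the capacity condition used in add_to_flow above (shown as a fragment, not a runnable module; predicate stands in for the lambda used in the real code):

# async source:
#     async with CrossSync.rm_aio(self._capacity_condition):
#         CrossSync.rm_aio(await self._capacity_condition.wait_for(predicate))
# generated sync code, approximately:
#     with self._capacity_condition:
#         self._capacity_condition.wait_for(predicate)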
Block until we have capacity - await self._capacity_condition.wait_for( - lambda: self._has_capacity(next_count, next_size) + CrossSync.rm_aio( + await self._capacity_condition.wait_for( + lambda: self._has_capacity(next_count, next_size) + ) ) yield mutations[start_idx:end_idx] @@ -280,8 +282,10 @@ async def _timer_routine(self, interval: float | None) -> None: return None while not self._closed.is_set(): # wait until interval has passed, or until closed - await CrossSync.event_wait( - self._closed, timeout=interval, async_break_early=False + CrossSync.rm_aio( + await CrossSync.event_wait( + self._closed, timeout=interval, async_break_early=False + ) ) if not self._closed.is_set() and self._staged_entries: self._schedule_flush() @@ -314,7 +318,7 @@ async def append(self, mutation_entry: RowMutationEntry): ): self._schedule_flush() # yield to the event loop to allow flush to run - await CrossSync.yield_to_event_loop() + CrossSync.rm_aio(await CrossSync.yield_to_event_loop()) def _schedule_flush(self) -> CrossSync.Future[None] | None: """ @@ -346,13 +350,17 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): """ # flush new entries in_process_requests: list[CrossSync.Future[list[FailedMutationEntryError]]] = [] - async for batch in self._flow_control.add_to_flow(new_entries): + async for batch in CrossSync.rm_aio( + self._flow_control.add_to_flow(new_entries) + ): batch_task = CrossSync.create_task( self._execute_mutate_rows, batch, sync_executor=self._sync_executor ) in_process_requests.append(batch_task) # wait for all inflight requests to complete - found_exceptions = await self._wait_for_batch_results(*in_process_requests) + found_exceptions = CrossSync.rm_aio( + await self._wait_for_batch_results(*in_process_requests) + ) # update exception data to reflect any new errors self._entries_processed_since_last_raise += len(new_entries) self._add_exceptions(found_exceptions) @@ -382,7 +390,7 @@ async def _execute_mutate_rows( attempt_timeout=self._attempt_timeout, retryable_exceptions=self._retryable_errors, ) - await operation.start() + CrossSync.rm_aio(await operation.start()) except MutationsExceptionGroup as e: # strip index information from exceptions, since it is not useful in a batch context for subexc in e.exceptions: @@ -390,7 +398,7 @@ async def _execute_mutate_rows( return list(e.exceptions) finally: # mark batch as complete in flow control - await self._flow_control.remove_from_flow(batch) + CrossSync.rm_aio(await self._flow_control.remove_from_flow(batch)) return [] def _add_exceptions(self, excs: list[Exception]): @@ -450,7 +458,7 @@ async def __aexit__(self, exc_type, exc, tb): Flushes the batcher and cleans up resources. 
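Since __aexit__ delegates to close(), the batcher is normally used as a context manager. A minimal usage sketch, assuming a TableAsync instance and a list of RowMutationEntry objects built elsewhere:

async def write_rows(table, entries):
    async with table.mutations_batcher() as batcher:
        for entry in entries:
            await batcher.append(entry)
    # leaving the block awaits close(), which flushes any staged entries and
    # waits for in-flight flush jobs before shutting down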
""" - await self.close() + CrossSync.rm_aio(await self.close()) @property def closed(self) -> bool: @@ -468,7 +476,7 @@ async def close(self): self._closed.set() self._flush_timer.cancel() self._schedule_flush() - await CrossSync.wait([*self._flush_jobs, self._flush_timer]) + CrossSync.rm_aio(await CrossSync.wait([*self._flush_jobs, self._flush_timer])) # shut down executor if self._sync_executor: with self._sync_executor: @@ -512,7 +520,7 @@ async def _wait_for_batch_results( for task in tasks: if CrossSync.is_async: # futures don't need to be awaited in sync mode - await task + CrossSync.rm_aio(await task) try: exc_list = task.result() if exc_list: diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 6dcbd22fb..0af5f0c4a 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -222,6 +222,10 @@ def verify_async_event_loop() -> None: """ asyncio.get_running_loop() + @staticmethod + def rm_aio(statement: Any) -> Any: + return statement + class _Sync_Impl: """ Provide sync versions of the async functions and types in CrossSync diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py index 3ce8a8d16..76350f443 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -236,6 +236,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): ) ) # convert class contents + wrapped_node = transformers_globals["RmAioFunctions"]().visit(wrapped_node) replace_dict = self.replace_symbols or {} replace_dict.update({"CrossSync": f"CrossSync._Sync_Impl"}) wrapped_node = transformers_globals["SymbolReplacer"](replace_dict).visit( @@ -273,11 +274,17 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): Transform async method into sync """ import ast + # replace async function with sync function wrapped_node = ast.copy_location( - ast.FunctionDef(wrapped_node.name, wrapped_node.args, - wrapped_node.body, wrapped_node.decorator_list, wrapped_node.returns, - ), wrapped_node, + ast.FunctionDef( + wrapped_node.name, + wrapped_node.args, + wrapped_node.body, + wrapped_node.decorator_list, + wrapped_node.returns, + ), + wrapped_node, ) # update name if specified if self.sync_name: @@ -318,6 +325,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): convert async to sync """ import ast + converted = transformers_globals["AsyncToSync"]().visit(wrapped_node) return converted From 37b4833736a5e6fc8d2baf96312f694602a922a7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 17 Jul 2024 17:36:33 -0700 Subject: [PATCH 192/360] reverted client and test changes --- google/cloud/bigtable/data/__init__.py | 2 +- .../bigtable/data/_async/_mutate_rows.py | 49 +- .../cloud/bigtable/data/_async/_read_rows.py | 53 +- google/cloud/bigtable/data/_async/client.py | 374 +++---- .../bigtable/data/_async/mutations_batcher.py | 212 ++-- google/cloud/bigtable/data/_helpers.py | 3 - google/cloud/bigtable/data/exceptions.py | 15 - google/cloud/bigtable/data/mutations.py | 12 - .../transports/pooled_grpc_asyncio.py | 27 +- tests/system/data/__init__.py | 3 - tests/system/data/setup_fixtures.py | 25 + tests/system/data/test_system.py | 942 +++++++++++++++++ tests/system/data/test_system_async.py | 992 ------------------ tests/unit/data/_async/__init__.py | 0 tests/unit/data/_async/test__mutate_rows.py | 110 +- 
tests/unit/data/_async/test__read_rows.py | 75 +- tests/unit/data/_async/test_client.py | 900 +++++++--------- .../data/_async/test_mutations_batcher.py | 782 +++++++------- .../data/_async/test_read_rows_acceptance.py | 351 ------- tests/unit/data/test_read_rows_acceptance.py | 331 ++++++ 20 files changed, 2445 insertions(+), 2813 deletions(-) create mode 100644 tests/system/data/test_system.py delete mode 100644 tests/system/data/test_system_async.py delete mode 100644 tests/unit/data/_async/__init__.py delete mode 100644 tests/unit/data/_async/test_read_rows_acceptance.py create mode 100644 tests/unit/data/test_read_rows_acceptance.py diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index 66fe3479b..5229f8021 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -50,10 +50,10 @@ __all__ = ( "BigtableDataClientAsync", "TableAsync", - "MutationsBatcherAsync", "RowKeySamples", "ReadRowsQuery", "RowRange", + "MutationsBatcherAsync", "Mutation", "RowMutationEntry", "SetCell", diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index e62d43397..99b9944cd 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -15,10 +15,12 @@ from __future__ import annotations from typing import Sequence, TYPE_CHECKING +from dataclasses import dataclass import functools from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries +import google.cloud.bigtable_v2.types.bigtable as types_pb import google.cloud.bigtable.data.exceptions as bt_exceptions from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _attempt_timeout_generator @@ -26,25 +28,25 @@ # mutate_rows requests are limited to this number of mutations from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT -from google.cloud.bigtable.data.mutations import _EntryWithProto - -from google.cloud.bigtable.data._sync.cross_sync import CrossSync if TYPE_CHECKING: + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) from google.cloud.bigtable.data.mutations import RowMutationEntry + from google.cloud.bigtable.data._async.client import TableAsync - if CrossSync.is_async: - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) - CrossSync.add_mapping("GapicClient", BigtableAsyncClient) +@dataclass +class _EntryWithProto: + """ + A dataclass to hold a RowMutationEntry and its corresponding proto representation. + """ + + entry: RowMutationEntry + proto: types_pb.MutateRowsRequest.Entry -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", - add_mapping_for_name="_MutateRowsOperation", -) class _MutateRowsOperationAsync: """ MutateRowsOperation manages the logic of sending a set of row mutations, @@ -64,11 +66,10 @@ class _MutateRowsOperationAsync: If not specified, the request will run until operation_timeout is reached. 
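Illustrative aside (not part of the patch): the operation class documented above is normally driven through TableAsync.bulk_mutate_rows, which appears later in this diff. A minimal usage sketch, assuming an existing TableAsync named table and a column family "test-family":

    from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell
    from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup

    async def bulk_write(table):
        entries = [
            RowMutationEntry(b"row-1", [SetCell("test-family", b"q", b"a")]),
            RowMutationEntry(b"row-2", [SetCell("test-family", b"q", b"b")]),
        ]
        try:
            # drives a _MutateRowsOperationAsync internally
            await table.bulk_mutate_rows(entries, operation_timeout=60)
        except MutationsExceptionGroup as e:
            # per-entry failures surface as FailedMutationEntryError sub-exceptions
            for failed in e.exceptions:
                print(failed)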
""" - @CrossSync.convert def __init__( self, - gapic_client: "CrossSync.GapicClient", - table: "CrossSync.Table", + gapic_client: "BigtableAsyncClient", + table: "TableAsync", mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, @@ -99,7 +100,7 @@ def __init__( bt_exceptions._MutateRowsIncomplete, ) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - self._operation = lambda: CrossSync.retry_target( + self._operation = retries.retry_target_async( self._run_attempt, self.is_retryable, sleep_generator, @@ -114,7 +115,6 @@ def __init__( self.remaining_indices = list(range(len(self.mutations))) self.errors: dict[int, list[Exception]] = {} - @CrossSync.convert async def start(self): """ Start the operation, and run until completion @@ -124,7 +124,7 @@ async def start(self): """ try: # trigger mutate_rows - CrossSync.rm_aio(await self._operation()) + await self._operation except Exception as exc: # exceptions raised by retryable are added to the list of exceptions for all unfinalized mutations incomplete_indices = self.remaining_indices.copy() @@ -151,7 +151,6 @@ async def start(self): all_errors, len(self.mutations) ) - @CrossSync.convert async def _run_attempt(self): """ Run a single attempt of the mutate_rows rpc. @@ -172,14 +171,12 @@ async def _run_attempt(self): return # make gapic request try: - result_generator = CrossSync.rm_aio( - await self._gapic_fn( - timeout=next(self.timeout_generator), - entries=request_entries, - retry=None, - ) + result_generator = await self._gapic_fn( + timeout=next(self.timeout_generator), + entries=request_entries, + retry=None, ) - async for result_list in CrossSync.rm_aio(result_generator): + async for result_list in result_generator: for result in result_list.entries: # convert sub-request index to global index orig_idx = active_request_indices[result.index] diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 2fe48e9e9..78cb7a991 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -15,7 +15,13 @@ from __future__ import annotations -from typing import Sequence +from typing import ( + TYPE_CHECKING, + AsyncGenerator, + AsyncIterable, + Awaitable, + Sequence, +) from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB @@ -26,7 +32,6 @@ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _RowSetComplete -from google.cloud.bigtable.data.exceptions import _ResetRow from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory @@ -34,13 +39,15 @@ from google.api_core import retry as retries from google.api_core.retry import exponential_sleep_generator -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if TYPE_CHECKING: + from google.cloud.bigtable.data._async.client import TableAsync + + +class _ResetRow(Exception): + def __init__(self, chunk): + self.chunk = chunk -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", - add_mapping_for_name="_ReadRowsOperation", -) class _ReadRowsOperationAsync: """ ReadRowsOperation handles the 
logic of merging chunks from a ReadRowsResponse stream @@ -75,7 +82,7 @@ class _ReadRowsOperationAsync: def __init__( self, query: ReadRowsQuery, - table: "CrossSync.Table", + table: "TableAsync", operation_timeout: float, attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), @@ -101,14 +108,14 @@ def __init__( self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None - def start_operation(self) -> CrossSync.Iterable[Row]: + def start_operation(self) -> AsyncGenerator[Row, None]: """ Start the read_rows operation, retrying on retryable errors. Yields: Row: The next row in the stream """ - return CrossSync.retry_target_stream( + return retries.retry_target_stream_async( self._read_rows_attempt, self._predicate, exponential_sleep_generator(0.01, 60, multiplier=2), @@ -116,7 +123,7 @@ def start_operation(self) -> CrossSync.Iterable[Row]: exception_factory=_retry_exception_factory, ) - def _read_rows_attempt(self) -> CrossSync.Iterable[Row]: + def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: """ Attempt a single read_rows rpc call. This function is intended to be wrapped by retry logic, @@ -152,10 +159,9 @@ def _read_rows_attempt(self) -> CrossSync.Iterable[Row]: chunked_stream = self.chunk_stream(gapic_stream) return self.merge_rows(chunked_stream) - @CrossSync.convert async def chunk_stream( - self, stream: CrossSync.Awaitable[CrossSync.Iterable[ReadRowsResponsePB]] - ) -> CrossSync.Iterable[ReadRowsResponsePB.CellChunk]: + self, stream: Awaitable[AsyncIterable[ReadRowsResponsePB]] + ) -> AsyncGenerator[ReadRowsResponsePB.CellChunk, None]: """ process chunks out of raw read_rows stream @@ -164,7 +170,7 @@ async def chunk_stream( Yields: ReadRowsResponsePB.CellChunk: the next chunk in the stream """ - async for resp in CrossSync.rm_aio(await stream): + async for resp in await stream: # extract proto from proto-plus wrapper resp = resp._pb @@ -205,12 +211,9 @@ async def chunk_stream( current_key = None @staticmethod - @CrossSync.convert( - replace_symbols={"__aiter__": "__iter__", "__anext__": "__next__"} - ) async def merge_rows( - chunks: CrossSync.Iterable[ReadRowsResponsePB.CellChunk] | None, - ) -> CrossSync.Iterable[Row]: + chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None + ) -> AsyncGenerator[Row, None]: """ Merge chunks into rows @@ -225,8 +228,8 @@ async def merge_rows( # For each row while True: try: - c = CrossSync.rm_aio(await it.__anext__()) - except CrossSync.StopIteration: + c = await it.__anext__() + except StopAsyncIteration: # stream complete return row_key = c.row_key @@ -274,7 +277,7 @@ async def merge_rows( buffer = [value] while c.value_size > 0: # throws when premature end - c = CrossSync.rm_aio(await it.__anext__()) + c = await it.__anext__() t = c.timestamp_micros cl = c.labels @@ -306,7 +309,7 @@ async def merge_rows( if c.commit_row: yield Row(row_key, cells) break - c = CrossSync.rm_aio(await it.__anext__()) + c = await it.__anext__() except _ResetRow as e: c = e.chunk if ( @@ -319,7 +322,7 @@ async def merge_rows( ): raise InvalidChunk("reset row with data") continue - except CrossSync.StopIteration: + except StopAsyncIteration: raise InvalidChunk("premature end of stream") @staticmethod diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index f18b46256..34fdf847a 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -25,18 +25,22 @@ TYPE_CHECKING, ) 
+import asyncio +import grpc import time import warnings +import sys import random import os -import concurrent.futures from functools import partial -from grpc import Channel from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.transports.base import ( - DEFAULT_CLIENT_INFO, +from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient +from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, + PooledChannel, ) from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest from google.cloud.client import ClientWithProject @@ -45,6 +49,7 @@ from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import ServiceUnavailable from google.api_core.exceptions import Aborted +from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync import google.auth.credentials import google.auth._default @@ -55,6 +60,8 @@ from google.cloud.bigtable.data.exceptions import FailedQueryShardError from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup +from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry +from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import _WarmedInstanceKey from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT @@ -64,53 +71,21 @@ from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import _attempt_timeout_generator -from google.cloud.bigtable.data._helpers import _MB_SIZE -from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry - +from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync +from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule from google.cloud.bigtable.data.row_filters import RowFilter from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter from google.cloud.bigtable.data.row_filters import RowFilterChain -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -if CrossSync.is_async: - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel as AsyncPooledChannel, - ) - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) - - # define file-specific cross-sync replacements - CrossSync.add_mapping("GapicClient", BigtableAsyncClient) - CrossSync.add_mapping("PooledTransport", PooledBigtableGrpcAsyncIOTransport) - CrossSync.add_mapping("PooledChannel", AsyncPooledChannel) - CrossSync.add_mapping("_ReadRowsOperation", 
_ReadRowsOperationAsync) - CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) - CrossSync.add_mapping("MutationsBatcher", MutationsBatcherAsync) - if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync.client.BigtableDataClient", - add_mapping_for_name="DataClient", -) class BigtableDataClientAsync(ClientWithProject): - @CrossSync.convert def __init__( self, *, @@ -145,8 +120,8 @@ def __init__( ValueError: if pool_size is less than 1 """ # set up transport in registry - transport_str = f"bt-{self._client_version()}-{pool_size}" - transport = CrossSync.PooledTransport.with_fixed_size(pool_size) + transport_str = f"pooled_grpc_asyncio_{pool_size}" + transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) BigtableClientMeta._transport_registry[transport_str] = transport # set up client info headers for veneer library client_info = DEFAULT_CLIENT_INFO @@ -171,24 +146,22 @@ def __init__( project=project, client_options=client_options, ) - self._gapic_client = CrossSync.GapicClient( + self._gapic_client = BigtableAsyncClient( transport=transport_str, credentials=credentials, client_options=client_options, client_info=client_info, ) - self._is_closed = CrossSync.Event() - self.transport = cast(CrossSync.PooledTransport, self._gapic_client.transport) + self.transport = cast( + PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport + ) # keep track of active instances to for warmup on channel refresh self._active_instances: Set[_WarmedInstanceKey] = set() # keep track of table objects associated with each instance # only remove instance from _active_instances when all associated tables remove it self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() - self._channel_refresh_tasks: list[CrossSync.Task[None]] = [] - self._executor = ( - concurrent.futures.ThreadPoolExecutor() if not CrossSync.is_async else None - ) + self._channel_refresh_tasks: list[asyncio.Task[None]] = [] if self._emulator_host is not None: # connect to an emulator host warnings.warn( @@ -196,7 +169,7 @@ def __init__( RuntimeWarning, stacklevel=2, ) - self.transport._grpc_channel = CrossSync.PooledChannel( + self.transport._grpc_channel = PooledChannel( pool_size=pool_size, host=self._emulator_host, insecure=True, @@ -221,10 +194,7 @@ def _client_version() -> str: """ Helper function to return the client version string for this client """ - version_str = f"{google.cloud.bigtable.__version__}-data" - if CrossSync.is_async: - version_str += "-async" - return version_str + return f"{google.cloud.bigtable.__version__}-data-async" def _start_background_channel_refresh(self) -> None: """ @@ -233,41 +203,31 @@ def _start_background_channel_refresh(self) -> None: Raises: RuntimeError: if not called in an asyncio event loop """ - if ( - not self._channel_refresh_tasks - and not self._emulator_host - and not self._is_closed.is_set() - ): - # raise error if not in an event loop in async client - CrossSync.verify_async_event_loop() + if not self._channel_refresh_tasks and not self._emulator_host: + # raise RuntimeError if there is no event loop + asyncio.get_running_loop() for channel_idx in range(self.transport.pool_size): - refresh_task = CrossSync.create_task( - self._manage_channel, - channel_idx, - sync_executor=self._executor, - task_name=f"{self.__class__.__name__} channel refresh 
{channel_idx}", - ) + refresh_task = asyncio.create_task(self._manage_channel(channel_idx)) + if sys.version_info >= (3, 8): + # task names supported in Python 3.8+ + refresh_task.set_name( + f"{self.__class__.__name__} channel refresh {channel_idx}" + ) self._channel_refresh_tasks.append(refresh_task) - @CrossSync.convert - async def close(self, timeout: float | None = 2.0): + async def close(self, timeout: float = 2.0): """ Cancel all background tasks """ - self._is_closed.set() for task in self._channel_refresh_tasks: task.cancel() - CrossSync.rm_aio(await self.transport.close()) - if self._executor: - self._executor.shutdown(wait=False) - CrossSync.rm_aio( - await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) - ) + group = asyncio.gather(*self._channel_refresh_tasks, return_exceptions=True) + await asyncio.wait_for(group, timeout=timeout) + await self.transport.close() self._channel_refresh_tasks = [] - @CrossSync.convert async def _ping_and_warm_instances( - self, channel: Channel, instance_key: _WarmedInstanceKey | None = None + self, channel: grpc.aio.Channel, instance_key: _WarmedInstanceKey | None = None ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel @@ -288,9 +248,8 @@ async def _ping_and_warm_instances( request_serializer=PingAndWarmRequest.serialize, ) # prepare list of coroutines to run - partial_list = [ - partial( - ping_rpc, + tasks = [ + ping_rpc( request={"name": instance_name, "app_profile_id": app_profile_id}, metadata=[ ( @@ -302,14 +261,11 @@ async def _ping_and_warm_instances( ) for (instance_name, table_name, app_profile_id) in instance_list ] - result_list = CrossSync.rm_aio( - await CrossSync.gather_partials( - partial_list, return_exceptions=True, sync_executor=self._executor - ) - ) + # execute coroutines in parallel + result_list = await asyncio.gather(*tasks, return_exceptions=True) + # return None in place of empty successful responses return [r or None for r in result_list] - @CrossSync.convert async def _manage_channel( self, channel_idx: int, @@ -343,37 +299,22 @@ async def _manage_channel( if next_sleep > 0: # warm the current channel immediately channel = self.transport.channels[channel_idx] - CrossSync.rm_aio(await self._ping_and_warm_instances(channel)) + await self._ping_and_warm_instances(channel) # continuously refresh the channel every `refresh_interval` seconds - while not self._is_closed.is_set(): - CrossSync.rm_aio( - await CrossSync.event_wait( - self._is_closed, - next_sleep, - async_break_early=False, # no need to interrupt sleep. 
Task will be cancelled on close - ) - ) - if self._is_closed.is_set(): - # don't refresh if client is closed - break + while True: + await asyncio.sleep(next_sleep) # prepare new channel for use new_channel = self.transport.grpc_channel._create_channel() - CrossSync.rm_aio(await self._ping_and_warm_instances(new_channel)) + await self._ping_and_warm_instances(new_channel) # cycle channel out of use, with long grace window before closure - start_timestamp = time.monotonic() - CrossSync.rm_aio( - await self.transport.replace_channel( - channel_idx, - grace=grace_period, - new_channel=new_channel, - event=self._is_closed, - ) + start_timestamp = time.time() + await self.transport.replace_channel( + channel_idx, grace=grace_period, swap_sleep=10, new_channel=new_channel ) # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) - next_sleep = next_refresh - (time.monotonic() - start_timestamp) + next_sleep = next_refresh - (time.time() - start_timestamp) - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: """ Registers an instance with the client, and warms the channel pool @@ -399,14 +340,11 @@ async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: # refresh tasks already running # call ping and warm on all existing channels for channel in self.transport.channels: - CrossSync.rm_aio( - await self._ping_and_warm_instances(channel, instance_key) - ) + await self._ping_and_warm_instances(channel, instance_key) else: # refresh tasks aren't active. start them as background tasks self._start_background_channel_refresh() - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def _remove_instance_registration( self, instance_id: str, owner: TableAsync ) -> bool: @@ -437,7 +375,6 @@ async def _remove_instance_registration( except KeyError: return False - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: """ Returns a table instance for making data API requests. 
All arguments are passed @@ -479,20 +416,15 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs """ return TableAsync(self, instance_id, table_id, *args, **kwargs) - @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): self._start_background_channel_refresh() return self - @CrossSync.convert(sync_name="__exit__", replace_symbols={"__aexit__": "__exit__"}) async def __aexit__(self, exc_type, exc_val, exc_tb): - CrossSync.rm_aio(await self.close()) - CrossSync.rm_aio(await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb)) + await self.close() + await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync.client.Table", add_mapping_for_name="Table" -) class TableAsync: """ Main Data API surface @@ -501,9 +433,6 @@ class TableAsync: each call """ - @CrossSync.convert( - replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} - ) def __init__( self, client: BigtableDataClientAsync, @@ -612,19 +541,17 @@ def __init__( default_mutate_rows_retryable_errors or () ) self.default_retryable_errors = default_retryable_errors or () + + # raises RuntimeError if called outside of an async context (no running event loop) try: - self._register_instance_future = CrossSync.create_task( - self.client._register_instance, - self.instance_id, - self, - sync_executor=self.client._executor, + self._register_instance_task = asyncio.create_task( + self.client._register_instance(instance_id, self) ) except RuntimeError as e: raise RuntimeError( f"{self.__class__.__name__} must be created within an async event loop context." ) from e - @CrossSync.convert(replace_symbols={"AsyncIterable": "Iterable"}) async def read_rows_stream( self, query: ReadRowsQuery, @@ -666,7 +593,7 @@ async def read_rows_stream( ) retryable_excs = _get_retryable_errors(retryable_errors, self) - row_merger = CrossSync._ReadRowsOperation( + row_merger = _ReadRowsOperationAsync( query, self, operation_timeout=operation_timeout, @@ -675,7 +602,6 @@ async def read_rows_stream( ) return row_merger.start_operation() - @CrossSync.convert async def read_rows( self, query: ReadRowsQuery, @@ -715,17 +641,14 @@ async def read_rows( from any retries that failed google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error """ - row_generator = CrossSync.rm_aio( - await self.read_rows_stream( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, - ) + row_generator = await self.read_rows_stream( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, ) - return CrossSync.rm_aio([row async for row in row_generator]) + return [row async for row in row_generator] - @CrossSync.convert async def read_row( self, row_key: str | bytes, @@ -765,19 +688,16 @@ async def read_row( if row_key is None: raise ValueError("row_key must be string or bytes") query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1) - results = CrossSync.rm_aio( - await self.read_rows( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, - ) + results = await self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, ) if len(results) == 0: return None return results[0] - @CrossSync.convert async def read_rows_sharded( self, sharded_query: 
ShardedQuery, @@ -828,35 +748,25 @@ async def read_rows_sharded( ) # limit the number of concurrent requests using a semaphore - concurrency_sem = CrossSync.Semaphore(_CONCURRENCY_LIMIT) + concurrency_sem = asyncio.Semaphore(_CONCURRENCY_LIMIT) async def read_rows_with_semaphore(query): - async with CrossSync.rm_aio(concurrency_sem): + async with concurrency_sem: # calculate new timeout based on time left in overall operation shard_timeout = next(rpc_timeout_generator) if shard_timeout <= 0: raise DeadlineExceeded( "Operation timeout exceeded before starting query" ) - return CrossSync.rm_aio( - await self.read_rows( - query, - operation_timeout=shard_timeout, - attempt_timeout=min(attempt_timeout, shard_timeout), - retryable_errors=retryable_errors, - ) + return await self.read_rows( + query, + operation_timeout=shard_timeout, + attempt_timeout=min(attempt_timeout, shard_timeout), + retryable_errors=retryable_errors, ) - routine_list = [ - partial(read_rows_with_semaphore, query) for query in sharded_query - ] - batch_result = CrossSync.rm_aio( - await CrossSync.gather_partials( - routine_list, - return_exceptions=True, - sync_executor=self.client._executor, - ) - ) + routine_list = [read_rows_with_semaphore(query) for query in sharded_query] + batch_result = await asyncio.gather(*routine_list, return_exceptions=True) # collect results and errors error_dict = {} @@ -883,7 +793,6 @@ async def read_rows_with_semaphore(query): ) return results_list - @CrossSync.convert async def row_exists( self, row_key: str | bytes, @@ -924,17 +833,14 @@ async def row_exists( limit_filter = CellsRowLimitFilter(1) chain_filter = RowFilterChain(filters=[limit_filter, strip_filter]) query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter) - results = CrossSync.rm_aio( - await self.read_rows( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, - ) + results = await self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, ) return len(results) > 0 - @CrossSync.convert async def sample_row_keys( self, *, @@ -990,30 +896,23 @@ async def sample_row_keys( metadata = _make_metadata(self.table_name, self.app_profile_id) async def execute_rpc(): - results = CrossSync.rm_aio( - await self.client._gapic_client.sample_row_keys( - table_name=self.table_name, - app_profile_id=self.app_profile_id, - timeout=next(attempt_timeout_gen), - metadata=metadata, - retry=None, - ) - ) - return CrossSync.rm_aio( - [(s.row_key, s.offset_bytes) async for s in results] + results = await self.client._gapic_client.sample_row_keys( + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=next(attempt_timeout_gen), + metadata=metadata, + retry=None, ) + return [(s.row_key, s.offset_bytes) async for s in results] - return CrossSync.rm_aio( - await CrossSync.retry_target( - execute_rpc, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, - ) + return await retries.retry_target_async( + execute_rpc, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, ) - @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) def mutations_batcher( self, *, @@ -1051,7 +950,7 @@ def mutations_batcher( Returns: MutationsBatcherAsync: a MutationsBatcherAsync context manager that can batch requests """ - return CrossSync.MutationsBatcher( + return 
MutationsBatcherAsync( self, flush_interval=flush_interval, flush_limit_mutation_count=flush_limit_mutation_count, @@ -1063,7 +962,6 @@ def mutations_batcher( batch_retryable_errors=batch_retryable_errors, ) - @CrossSync.convert async def mutate_row( self, row_key: str | bytes, @@ -1134,17 +1032,14 @@ async def mutate_row( metadata=_make_metadata(self.table_name, self.app_profile_id), retry=None, ) - return CrossSync.rm_aio( - await CrossSync.retry_target( - target, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, - ) + return await retries.retry_target_async( + target, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, ) - @CrossSync.convert async def bulk_mutate_rows( self, mutation_entries: list[RowMutationEntry], @@ -1190,7 +1085,7 @@ async def bulk_mutate_rows( ) retryable_excs = _get_retryable_errors(retryable_errors, self) - operation = CrossSync._MutateRowsOperation( + operation = _MutateRowsOperationAsync( self.client._gapic_client, self, mutation_entries, @@ -1198,9 +1093,8 @@ async def bulk_mutate_rows( attempt_timeout, retryable_exceptions=retryable_excs, ) - CrossSync.rm_aio(await operation.start()) + await operation.start() - @CrossSync.convert async def check_and_mutate_row( self, row_key: str | bytes, @@ -1254,24 +1148,19 @@ async def check_and_mutate_row( false_case_mutations = [false_case_mutations] false_case_list = [m._to_pb() for m in false_case_mutations or []] metadata = _make_metadata(self.table_name, self.app_profile_id) - result = CrossSync.rm_aio( - await self.client._gapic_client.check_and_mutate_row( - true_mutations=true_case_list, - false_mutations=false_case_list, - predicate_filter=predicate._to_pb() if predicate is not None else None, - row_key=row_key.encode("utf-8") - if isinstance(row_key, str) - else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, - ) + result = await self.client._gapic_client.check_and_mutate_row( + true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, ) return result.predicate_matched - @CrossSync.convert async def read_modify_write_row( self, row_key: str | bytes, @@ -1310,34 +1199,25 @@ async def read_modify_write_row( if not rules: raise ValueError("rules must contain at least one item") metadata = _make_metadata(self.table_name, self.app_profile_id) - result = CrossSync.rm_aio( - await self.client._gapic_client.read_modify_write_row( - rules=[rule._to_pb() for rule in rules], - row_key=row_key.encode("utf-8") - if isinstance(row_key, str) - else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, - ) + result = await self.client._gapic_client.read_modify_write_row( + rules=[rule._to_pb() for rule in rules], + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, ) # construct Row from result return Row._from_pb(result.row) - @CrossSync.convert async def close(self): """ Called to close the Table instance and release 
any resources held by it. """ - if self._register_instance_future: - self._register_instance_future.cancel() - CrossSync.rm_aio( - await self.client._remove_instance_registration(self.instance_id, self) - ) + self._register_instance_task.cancel() + await self.client._remove_instance_registration(self.instance_id, self) - @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """ Implement async context manager protocol @@ -1345,11 +1225,9 @@ async def __aenter__(self): Ensure registration task has time to run, so that grpc channels will be warmed for the specified instance """ - if self._register_instance_future: - CrossSync.rm_aio(await self._register_instance_future) + await self._register_instance_task return self - @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc_val, exc_tb): """ Implement async context manager protocol @@ -1357,4 +1235,4 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): Unregister this instance with the client, so that grpc channels will no longer be warmed """ - CrossSync.rm_aio(await self.close()) + await self.close() diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 7a6def9e4..76d13f00b 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -14,37 +14,32 @@ # from __future__ import annotations -from typing import Sequence, TYPE_CHECKING +from typing import Any, Sequence, TYPE_CHECKING +import asyncio import atexit import warnings from collections import deque -import concurrent.futures +from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data._helpers import _MB_SIZE -from google.cloud.bigtable.data.mutations import ( +from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync +from google.cloud.bigtable.data._async._mutate_rows import ( _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, ) from google.cloud.bigtable.data.mutations import Mutation -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - if TYPE_CHECKING: - from google.cloud.bigtable.data.mutations import RowMutationEntry + from google.cloud.bigtable.data._async.client import TableAsync - if CrossSync.is_async: - from google.cloud.bigtable.data._async.client import TableAsync +# used to make more readable default values +_MB_SIZE = 1024 * 1024 -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync.mutations_batcher._FlowControl", - add_mapping_for_name="_FlowControl", -) class _FlowControlAsync: """ Manages flow control for batched mutations. 
Mutations are registered against @@ -75,7 +70,7 @@ def __init__( raise ValueError("max_mutation_count must be greater than 0") if self._max_mutation_bytes < 1: raise ValueError("max_mutation_bytes must be greater than 0") - self._capacity_condition = CrossSync.Condition() + self._capacity_condition = asyncio.Condition() self._in_flight_mutation_count = 0 self._in_flight_mutation_bytes = 0 @@ -101,7 +96,6 @@ def _has_capacity(self, additional_count: int, additional_size: int) -> bool: new_count = self._in_flight_mutation_count + additional_count return new_size <= acceptable_size and new_count <= acceptable_count - @CrossSync.convert async def remove_from_flow( self, mutations: RowMutationEntry | list[RowMutationEntry] ) -> None: @@ -120,10 +114,9 @@ async def remove_from_flow( self._in_flight_mutation_count -= total_count self._in_flight_mutation_bytes -= total_size # notify any blocked requests that there is additional capacity - async with CrossSync.rm_aio(self._capacity_condition): + async with self._capacity_condition: self._capacity_condition.notify_all() - @CrossSync.convert async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): """ Generator function that registers mutations with flow control. As mutations @@ -146,7 +139,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] start_idx = end_idx batch_mutation_count = 0 # fill up batch until we hit capacity - async with CrossSync.rm_aio(self._capacity_condition): + async with self._capacity_condition: while end_idx < len(mutations): next_entry = mutations[end_idx] next_size = next_entry.size() @@ -167,19 +160,12 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] break else: # batch is empty. Block until we have capacity - CrossSync.rm_aio( - await self._capacity_condition.wait_for( - lambda: self._has_capacity(next_count, next_size) - ) + await self._capacity_condition.wait_for( + lambda: self._has_capacity(next_count, next_size) ) yield mutations[start_idx:end_idx] -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", - mypy_ignore=["unreachable"], - add_mapping_for_name="MutationsBatcher", -) class MutationsBatcherAsync: """ Allows users to send batches using context manager API: @@ -211,10 +197,9 @@ class MutationsBatcherAsync: Defaults to the Table's default_mutate_rows_retryable_errors. 
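Illustrative aside (not part of the patch): a minimal sketch of the batcher API described above, assuming an existing TableAsync named table and a column family "test-family":

    from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell

    async def write_rows(table):
        async with table.mutations_batcher(flush_limit_mutation_count=100) as batcher:
            for i in range(1000):
                entry = RowMutationEntry(
                    f"row-{i}".encode(),
                    [SetCell("test-family", b"qualifier", b"value")],
                )
                # entries are buffered and flushed in the background
                await batcher.append(entry)
        # leaving the context flushes remaining entries and raises
        # MutationsExceptionGroup if any mutation ultimately failed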
""" - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def __init__( self, - table: TableAsync, + table: "TableAsync", *, flush_interval: float | None = 5, flush_limit_mutation_count: int | None = 1000, @@ -233,11 +218,11 @@ def __init__( batch_retryable_errors, table ) - self._closed = CrossSync.Event() + self.closed: bool = False self._table = table self._staged_entries: list[RowMutationEntry] = [] self._staged_count, self._staged_bytes = 0, 0 - self._flow_control = CrossSync._FlowControl( + self._flow_control = _FlowControlAsync( flow_control_max_mutation_count, flow_control_max_bytes ) self._flush_limit_bytes = flush_limit_bytes @@ -246,15 +231,8 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - self._sync_executor = ( - concurrent.futures.ThreadPoolExecutor(max_workers=8) - if not CrossSync.is_async - else None - ) - self._flush_timer = CrossSync.create_task( - self._timer_routine, flush_interval, sync_executor=self._sync_executor - ) - self._flush_jobs: set[CrossSync.Future[None]] = set() + self._flush_timer = self._start_flush_timer(flush_interval) + self._flush_jobs: set[asyncio.Future[None]] = set() # MutationExceptionGroup reports number of successful entries along with failures self._entries_processed_since_last_raise: int = 0 self._exceptions_since_last_raise: int = 0 @@ -267,8 +245,7 @@ def __init__( # clean up on program exit atexit.register(self._on_exit) - @CrossSync.convert - async def _timer_routine(self, interval: float | None) -> None: + def _start_flush_timer(self, interval: float | None) -> asyncio.Future[None]: """ Set up a background task to flush the batcher every interval seconds @@ -277,20 +254,27 @@ async def _timer_routine(self, interval: float | None) -> None: Args: flush_interval: Automatically flush every flush_interval seconds. If None, no time-based flushing is performed. 
+ Returns: + asyncio.Future[None]: future representing the background task """ - if not interval or interval <= 0: - return None - while not self._closed.is_set(): - # wait until interval has passed, or until closed - CrossSync.rm_aio( - await CrossSync.event_wait( - self._closed, timeout=interval, async_break_early=False - ) - ) - if not self._closed.is_set() and self._staged_entries: - self._schedule_flush() + if interval is None or self.closed: + empty_future: asyncio.Future[None] = asyncio.Future() + empty_future.set_result(None) + return empty_future + + async def timer_routine(self, interval: float): + """ + Triggers new flush tasks every `interval` seconds + """ + while not self.closed: + await asyncio.sleep(interval) + # add new flush task to list + if not self.closed and self._staged_entries: + self._schedule_flush() + + timer_task = asyncio.create_task(timer_routine(self, interval)) + return timer_task - @CrossSync.convert async def append(self, mutation_entry: RowMutationEntry): """ Add a new set of mutations to the internal queue @@ -302,7 +286,7 @@ async def append(self, mutation_entry: RowMutationEntry): ValueError: if an invalid mutation type is added """ # TODO: return a future to track completion of this entry - if self._closed.is_set(): + if self.closed: raise RuntimeError("Cannot append to closed MutationsBatcher") if isinstance(mutation_entry, Mutation): # type: ignore raise ValueError( @@ -318,29 +302,25 @@ async def append(self, mutation_entry: RowMutationEntry): ): self._schedule_flush() # yield to the event loop to allow flush to run - CrossSync.rm_aio(await CrossSync.yield_to_event_loop()) + await asyncio.sleep(0) - def _schedule_flush(self) -> CrossSync.Future[None] | None: + def _schedule_flush(self) -> asyncio.Future[None] | None: """ Update the flush task to include the latest staged entries Returns: - Future[None] | None: + asyncio.Future[None] | None: future representing the background task, if started """ if self._staged_entries: entries, self._staged_entries = self._staged_entries, [] self._staged_count, self._staged_bytes = 0, 0 - new_task = CrossSync.create_task( - self._flush_internal, entries, sync_executor=self._sync_executor - ) - if not new_task.done(): - self._flush_jobs.add(new_task) - new_task.add_done_callback(self._flush_jobs.remove) + new_task = self._create_bg_task(self._flush_internal, entries) + new_task.add_done_callback(self._flush_jobs.remove) + self._flush_jobs.add(new_task) return new_task return None - @CrossSync.convert async def _flush_internal(self, new_entries: list[RowMutationEntry]): """ Flushes a set of mutations to the server, and updates internal state @@ -349,23 +329,16 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): new_entries list of RowMutationEntry objects to flush """ # flush new entries - in_process_requests: list[CrossSync.Future[list[FailedMutationEntryError]]] = [] - async for batch in CrossSync.rm_aio( - self._flow_control.add_to_flow(new_entries) - ): - batch_task = CrossSync.create_task( - self._execute_mutate_rows, batch, sync_executor=self._sync_executor - ) + in_process_requests: list[asyncio.Future[list[FailedMutationEntryError]]] = [] + async for batch in self._flow_control.add_to_flow(new_entries): + batch_task = self._create_bg_task(self._execute_mutate_rows, batch) in_process_requests.append(batch_task) # wait for all inflight requests to complete - found_exceptions = CrossSync.rm_aio( - await self._wait_for_batch_results(*in_process_requests) - ) + found_exceptions = await 
self._wait_for_batch_results(*in_process_requests) # update exception data to reflect any new errors self._entries_processed_since_last_raise += len(new_entries) self._add_exceptions(found_exceptions) - @CrossSync.convert async def _execute_mutate_rows( self, batch: list[RowMutationEntry] ) -> list[FailedMutationEntryError]: @@ -382,7 +355,7 @@ async def _execute_mutate_rows( FailedMutationEntryError objects will not contain index information """ try: - operation = CrossSync._MutateRowsOperation( + operation = _MutateRowsOperationAsync( self._table.client._gapic_client, self._table, batch, @@ -390,7 +363,7 @@ async def _execute_mutate_rows( attempt_timeout=self._attempt_timeout, retryable_exceptions=self._retryable_errors, ) - CrossSync.rm_aio(await operation.start()) + await operation.start() except MutationsExceptionGroup as e: # strip index information from exceptions, since it is not useful in a batch context for subexc in e.exceptions: @@ -398,7 +371,7 @@ async def _execute_mutate_rows( return list(e.exceptions) finally: # mark batch as complete in flow control - CrossSync.rm_aio(await self._flow_control.remove_from_flow(batch)) + await self._flow_control.remove_from_flow(batch) return [] def _add_exceptions(self, excs: list[Exception]): @@ -446,41 +419,31 @@ def _raise_exceptions(self): entry_count=entry_count, ) - @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """Allow use of context manager API""" return self - @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc, tb): """ Allow use of context manager API. Flushes the batcher and cleans up resources. """ - CrossSync.rm_aio(await self.close()) - - @property - def closed(self) -> bool: - """ - Returns: - - True if the batcher is closed, False otherwise - """ - return self._closed.is_set() + await self.close() - @CrossSync.convert async def close(self): """ Flush queue and clean up resources """ - self._closed.set() + self.closed = True self._flush_timer.cancel() self._schedule_flush() - CrossSync.rm_aio(await CrossSync.wait([*self._flush_jobs, self._flush_timer])) - # shut down executor - if self._sync_executor: - with self._sync_executor: - self._sync_executor.shutdown(wait=True) + if self._flush_jobs: + await asyncio.gather(*self._flush_jobs, return_exceptions=True) + try: + await self._flush_timer + except asyncio.CancelledError: + pass atexit.unregister(self._on_exit) # raise unreported exceptions self._raise_exceptions() @@ -489,17 +452,32 @@ def _on_exit(self): """ Called when program is exited. Raises warning if unflushed mutations remain """ - if not self._closed.is_set() and self._staged_entries: + if not self.closed and self._staged_entries: warnings.warn( f"MutationsBatcher for table {self._table.table_name} was not closed. " f"{len(self._staged_entries)} Unflushed mutations will not be sent to the server." ) @staticmethod - @CrossSync.convert + def _create_bg_task(func, *args, **kwargs) -> asyncio.Future[Any]: + """ + Create a new background task, and return a future + + This method wraps asyncio to make it easier to maintain subclasses + with different concurrency models. 
+ + Args: + func: function to execute in background task + *args: positional arguments to pass to func + **kwargs: keyword arguments to pass to func + Returns: + asyncio.Future: Future object representing the background task + """ + return asyncio.create_task(func(*args, **kwargs)) + + @staticmethod async def _wait_for_batch_results( - *tasks: CrossSync.Future[list[FailedMutationEntryError]] - | CrossSync.Future[None], + *tasks: asyncio.Future[list[FailedMutationEntryError]] | asyncio.Future[None], ) -> list[Exception]: """ Takes in a list of futures representing _execute_mutate_rows tasks, @@ -516,19 +494,19 @@ async def _wait_for_batch_results( """ if not tasks: return [] - exceptions: list[Exception] = [] - for task in tasks: - if CrossSync.is_async: - # futures don't need to be awaited in sync mode - CrossSync.rm_aio(await task) - try: - exc_list = task.result() - if exc_list: - # expect a list of FailedMutationEntryError objects - for exc in exc_list: - # strip index information - exc.index = None - exceptions.extend(exc_list) - except Exception as e: - exceptions.append(e) - return exceptions + all_results = await asyncio.gather(*tasks, return_exceptions=True) + found_errors = [] + for result in all_results: + if isinstance(result, Exception): + # will receive direct Exception objects if request task fails + found_errors.append(result) + elif isinstance(result, BaseException): + # BaseException not expected from grpc calls. Raise immediately + raise result + elif result: + # completed requests will return a list of FailedMutationEntryError + for e in result: + # strip index information + e.index = None + found_errors.extend(result) + return found_errors diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index a8113cc4a..a8fba9ef1 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -48,9 +48,6 @@ "_WarmedInstanceKey", ["instance_name", "table_name", "app_profile_id"] ) -# used to make more readable default values -_MB_SIZE = 1024 * 1024 - # enum used on method calls when table defaults should be used class TABLE_DEFAULT(enum.Enum): diff --git a/google/cloud/bigtable/data/exceptions.py b/google/cloud/bigtable/data/exceptions.py index 8065ed9d1..8d97640aa 100644 --- a/google/cloud/bigtable/data/exceptions.py +++ b/google/cloud/bigtable/data/exceptions.py @@ -41,21 +41,6 @@ class _RowSetComplete(Exception): pass -class _ResetRow(Exception): # noqa: F811 - """ - Internal exception for _ReadRowsOperation - - Denotes that the server sent a reset_row marker, telling the client to drop - all previous chunks for row_key and re-read from the beginning. - - Args: - chunk: the reset_row chunk - """ - - def __init__(self, chunk): - self.chunk = chunk - - class _MutateRowsIncomplete(RuntimeError): """ Exception raised when a mutate_rows call has unfinished work. diff --git a/google/cloud/bigtable/data/mutations.py b/google/cloud/bigtable/data/mutations.py index 2f4e441ed..335a15e12 100644 --- a/google/cloud/bigtable/data/mutations.py +++ b/google/cloud/bigtable/data/mutations.py @@ -366,15 +366,3 @@ def _from_dict(cls, input_dict: dict[str, Any]) -> RowMutationEntry: Mutation._from_dict(mutation) for mutation in input_dict["mutations"] ], ) - - -@dataclass -class _EntryWithProto: - """ - A dataclass to hold a RowMutationEntry and its corresponding proto representation. - - Used in _MutateRowsOperation to avoid repeated conversion of RowMutationEntry to proto. 
- """ - - entry: RowMutationEntry - proto: types_pb.MutateRowsRequest.Entry diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py index 864b4ecc2..372e5796d 100644 --- a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py +++ b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py @@ -150,7 +150,7 @@ async def wait_for_state_change(self, last_observed_state): raise NotImplementedError() async def replace_channel( - self, channel_idx, grace=1, new_channel=None, event=None + self, channel_idx, grace=None, swap_sleep=1, new_channel=None ) -> aio.Channel: """ Replaces a channel in the pool with a fresh one. @@ -160,14 +160,13 @@ async def replace_channel( Args: channel_idx(int): the channel index in the pool to replace - grace(Optional[float]): The time to wait for active RPCs to - finish. If a grace period is not specified (by passing None for + grace(Optional[float]): The time to wait until all active RPCs are + finished. If a grace period is not specified (by passing None for grace), all existing RPCs are cancelled immediately. - If event is set at close time, grace is ignored + swap_sleep(Optional[float]): The number of seconds to sleep in between + replacing channels and closing the old one new_channel(grpc.aio.Channel): a new channel to insert into the pool at `channel_idx`. If `None`, a new channel will be created. - event(Optional[threading.Event]): an event to signal when the - replacement should be aborted. If set, grace is ignored. """ if channel_idx >= len(self._pool) or channel_idx < 0: raise ValueError( @@ -177,8 +176,7 @@ async def replace_channel( new_channel = self._create_channel() old_channel = self._pool[channel_idx] self._pool[channel_idx] = new_channel - if event is not None and not event.is_set(): - grace = None + await asyncio.sleep(swap_sleep) await old_channel.close(grace=grace) return new_channel @@ -402,7 +400,7 @@ def channels(self) -> List[grpc.Channel]: return self._grpc_channel._pool async def replace_channel( - self, channel_idx, grace=1, new_channel=None, event=None + self, channel_idx, grace=None, swap_sleep=1, new_channel=None ) -> aio.Channel: """ Replaces a channel in the pool with a fresh one. @@ -412,17 +410,16 @@ async def replace_channel( Args: channel_idx(int): the channel index in the pool to replace - grace(Optional[float]): The time to wait for active RPCs to - finish. If a grace period is not specified (by passing None for + grace(Optional[float]): The time to wait until all active RPCs are + finished. If a grace period is not specified (by passing None for grace), all existing RPCs are cancelled immediately. - If event is set at close time, grace is ignored + swap_sleep(Optional[float]): The number of seconds to sleep in between + replacing channels and closing the old one new_channel(grpc.aio.Channel): a new channel to insert into the pool at `channel_idx`. If `None`, a new channel will be created. - event(Optional[threading.Event]): an event to signal when the - replacement should be aborted. If set, grace is ignored. 
""" return await self._grpc_channel.replace_channel( - channel_idx=channel_idx, grace=grace, new_channel=new_channel, event=event + channel_idx, grace, swap_sleep, new_channel ) diff --git a/tests/system/data/__init__.py b/tests/system/data/__init__.py index f2952b2cd..89a37dc92 100644 --- a/tests/system/data/__init__.py +++ b/tests/system/data/__init__.py @@ -13,6 +13,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -TEST_FAMILY = "test-family" -TEST_FAMILY_2 = "test-family-2" diff --git a/tests/system/data/setup_fixtures.py b/tests/system/data/setup_fixtures.py index 3b5a0af06..77086b7f3 100644 --- a/tests/system/data/setup_fixtures.py +++ b/tests/system/data/setup_fixtures.py @@ -17,10 +17,20 @@ """ import pytest +import pytest_asyncio import os +import asyncio import uuid +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.stop() + loop.close() + + @pytest.fixture(scope="session") def admin_client(): """ @@ -140,7 +150,22 @@ def table_id( print(f"Table {init_table_id} not found, skipping deletion") +@pytest_asyncio.fixture(scope="session") +async def client(): + from google.cloud.bigtable.data import BigtableDataClientAsync + + project = os.getenv("GOOGLE_CLOUD_PROJECT") or None + async with BigtableDataClientAsync(project=project, pool_size=4) as client: + yield client + + @pytest.fixture(scope="session") def project_id(client): """Returns the project ID from the client.""" yield client.project + + +@pytest_asyncio.fixture(scope="session") +async def table(client, table_id, instance_id): + async with client.get_table(instance_id, table_id) as table: + yield table diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py new file mode 100644 index 000000000..9fe208551 --- /dev/null +++ b/tests/system/data/test_system.py @@ -0,0 +1,942 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import pytest_asyncio +import asyncio +import uuid +import os +from google.api_core import retry +from google.api_core.exceptions import ClientError + +from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE +from google.cloud.environment_vars import BIGTABLE_EMULATOR + +TEST_FAMILY = "test-family" +TEST_FAMILY_2 = "test-family-2" + + +@pytest.fixture(scope="session") +def column_family_config(): + """ + specify column families to create when creating a new test table + """ + from google.cloud.bigtable_admin_v2 import types + + return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} + + +@pytest.fixture(scope="session") +def init_table_id(): + """ + The table_id to use when creating a new test table + """ + return f"test-table-{uuid.uuid4().hex}" + + +@pytest.fixture(scope="session") +def cluster_config(project_id): + """ + Configuration for the clusters to use when creating a new instance + """ + from google.cloud.bigtable_admin_v2 import types + + cluster = { + "test-cluster": types.Cluster( + location=f"projects/{project_id}/locations/us-central1-b", + serve_nodes=1, + ) + } + return cluster + + +class TempRowBuilder: + """ + Used to add rows to a table for testing purposes. + """ + + def __init__(self, table): + self.rows = [] + self.table = table + + async def add_row( + self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" + ): + if isinstance(value, str): + value = value.encode("utf-8") + elif isinstance(value, int): + value = value.to_bytes(8, byteorder="big", signed=True) + request = { + "table_name": self.table.table_name, + "row_key": row_key, + "mutations": [ + { + "set_cell": { + "family_name": family, + "column_qualifier": qualifier, + "value": value, + } + } + ], + } + await self.table.client._gapic_client.mutate_row(request) + self.rows.append(row_key) + + async def delete_rows(self): + if self.rows: + request = { + "table_name": self.table.table_name, + "entries": [ + {"row_key": row, "mutations": [{"delete_from_row": {}}]} + for row in self.rows + ], + } + await self.table.client._gapic_client.mutate_rows(request) + + +@pytest.mark.usefixtures("table") +async def _retrieve_cell_value(table, row_key): + """ + Helper to read an individual row + """ + from google.cloud.bigtable.data import ReadRowsQuery + + row_list = await table.read_rows(ReadRowsQuery(row_keys=row_key)) + assert len(row_list) == 1 + row = row_list[0] + cell = row.cells[0] + return cell.value + + +async def _create_row_and_mutation( + table, temp_rows, *, start_value=b"start", new_value=b"new_value" +): + """ + Helper to create a new row, and a sample set_cell mutation to change its value + """ + from google.cloud.bigtable.data.mutations import SetCell + + row_key = uuid.uuid4().hex.encode() + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row( + row_key, family=family, qualifier=qualifier, value=start_value + ) + # ensure cell is initialized + assert (await _retrieve_cell_value(table, row_key)) == start_value + + mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) + return row_key, mutation + + +@pytest_asyncio.fixture(scope="function") +async def temp_rows(table): + builder = TempRowBuilder(table) + yield builder + await builder.delete_rows() + + +@pytest.mark.usefixtures("table") +@pytest.mark.usefixtures("client") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=10) +@pytest.mark.asyncio +async def 
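TempRowBuilder.add_row normalizes whatever value it is given into bytes before building the mutate_row request, since Bigtable cell values are raw byte strings. The coercion is small enough to show on its own:

def to_cell_bytes(value):
    # mirrors the coercion in TempRowBuilder.add_row: str becomes UTF-8 bytes,
    # int becomes an 8-byte big-endian signed integer, bytes pass through
    if isinstance(value, str):
        return value.encode("utf-8")
    if isinstance(value, int):
        return value.to_bytes(8, byteorder="big", signed=True)
    return value

assert to_cell_bytes("abc") == b"abc"
assert to_cell_bytes(1) == b"\x00\x00\x00\x00\x00\x00\x00\x01"
assert to_cell_bytes(-1) == b"\xff\xff\xff\xff\xff\xff\xff\xff"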
test_ping_and_warm_gapic(client, table): + """ + Simple ping rpc test + This test ensures channels are able to authenticate with backend + """ + request = {"name": table.instance_name} + await client._gapic_client.ping_and_warm(request) + + +@pytest.mark.usefixtures("table") +@pytest.mark.usefixtures("client") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_ping_and_warm(client, table): + """ + Test ping and warm from handwritten client + """ + try: + channel = client.transport._grpc_channel.pool[0] + except Exception: + # for sync client + channel = client.transport._grpc_channel + results = await client._ping_and_warm_instances(channel) + assert len(results) == 1 + assert results[0] is None + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +async def test_mutation_set_cell(table, temp_rows): + """ + Ensure cells can be set properly + """ + row_key = b"bulk_mutate" + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + await table.mutate_row(row_key, mutation) + + # ensure cell is updated + assert (await _retrieve_cell_value(table, row_key)) == new_value + + +@pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" +) +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_sample_row_keys(client, table, temp_rows, column_split_config): + """ + Sample keys should return a single sample in small test tables + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + + results = await table.sample_row_keys() + assert len(results) == len(column_split_config) + 1 + # first keys should match the split config + for idx in range(len(column_split_config)): + assert results[idx][0] == column_split_config[idx] + assert isinstance(results[idx][1], int) + # last sample should be empty key + assert results[-1][0] == b"" + assert isinstance(results[-1][1], int) + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_bulk_mutations_set_cell(client, table, temp_rows): + """ + Ensure cells can be set properly + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + + await table.bulk_mutate_rows([bulk_mutation]) + + # ensure cell is updated + assert (await _retrieve_cell_value(table, row_key)) == new_value + + +@pytest.mark.asyncio +async def test_bulk_mutations_raise_exception(client, table): + """ + If an invalid mutation is passed, an exception should be raised + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + row_key = uuid.uuid4().hex.encode() + mutation = SetCell(family="nonexistent", qualifier=b"test-qualifier", new_value=b"") + bulk_mutation = RowMutationEntry(row_key, [mutation]) + + with pytest.raises(MutationsExceptionGroup) as exc: + await 
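The decorators on these tests come from google.api_core: wrapping a coroutine in retry.AsyncRetry makes transient ClientError responses retry with backoff instead of failing the run. A short sketch of the same decorator on an arbitrary coroutine; the flaky() body is a placeholder, and this assumes the api_core version pinned by this repo exposes AsyncRetry under google.api_core.retry, as the imports above do:

from google.api_core import retry
from google.api_core.exceptions import ClientError

# retry only on ClientError, starting at a 1s delay and capping at 5s,
# mirroring the decorator used on the system tests
@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
async def flaky():
    ...  # call the RPC under test here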
table.bulk_mutate_rows([bulk_mutation]) + assert len(exc.value.exceptions) == 1 + entry_error = exc.value.exceptions[0] + assert isinstance(entry_error, FailedMutationEntryError) + assert entry_error.index == 0 + assert entry_error.entry == bulk_mutation + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_mutations_batcher_context_manager(client, table, temp_rows): + """ + test batcher with context manager. Should flush on exit + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + row_key2, mutation2 = await _create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + async with table.mutations_batcher() as batcher: + await batcher.append(bulk_mutation) + await batcher.append(bulk_mutation2) + # ensure cell is updated + assert (await _retrieve_cell_value(table, row_key)) == new_value + assert len(batcher._staged_entries) == 0 + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_mutations_batcher_timer_flush(client, table, temp_rows): + """ + batch should occur after flush_interval seconds + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + flush_interval = 0.1 + async with table.mutations_batcher(flush_interval=flush_interval) as batcher: + await batcher.append(bulk_mutation) + await asyncio.sleep(0) + assert len(batcher._staged_entries) == 1 + await asyncio.sleep(flush_interval + 0.1) + assert len(batcher._staged_entries) == 0 + # ensure cell is updated + assert (await _retrieve_cell_value(table, row_key)) == new_value + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_mutations_batcher_count_flush(client, table, temp_rows): + """ + batch should flush after flush_limit_mutation_count mutations + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await _create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + async with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + # should be noop; flush not scheduled + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # task should now be scheduled + assert len(batcher._flush_jobs) == 1 + await asyncio.gather(*batcher._flush_jobs) + assert len(batcher._staged_entries) == 0 + assert len(batcher._flush_jobs) == 
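The batcher tests above exercise three independent flush triggers: an elapsed-time flush, a mutation-count flush, and a byte-size flush. A hedged sketch of how a caller would set all three at once; the keyword names are the ones these tests pass, the values here are arbitrary:

async def write_with_batcher(table, entries):
    # flush when 0.5s passes, 100 mutations are staged, or 1 MiB is staged,
    # whichever happens first; anything left is flushed when the context exits
    async with table.mutations_batcher(
        flush_interval=0.5,
        flush_limit_mutation_count=100,
        flush_limit_bytes=1024 * 1024,
    ) as batcher:
        for entry in entries:
            await batcher.append(entry)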
0 + # ensure cells were updated + assert (await _retrieve_cell_value(table, row_key)) == new_value + assert (await _retrieve_cell_value(table, row_key2)) == new_value2 + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_mutations_batcher_bytes_flush(client, table, temp_rows): + """ + batch should flush after flush_limit_bytes bytes + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await _create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 + + async with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # task should now be scheduled + assert len(batcher._flush_jobs) == 1 + assert len(batcher._staged_entries) == 0 + # let flush complete + await asyncio.gather(*batcher._flush_jobs) + # ensure cells were updated + assert (await _retrieve_cell_value(table, row_key)) == new_value + assert (await _retrieve_cell_value(table, row_key2)) == new_value2 + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_mutations_batcher_no_flush(client, table, temp_rows): + """ + test with no flush requirements met + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + start_value = b"unchanged" + row_key, mutation = await _create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await _create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 + async with table.mutations_batcher( + flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 + ) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # flush not scheduled + assert len(batcher._flush_jobs) == 0 + await asyncio.sleep(0.01) + assert len(batcher._staged_entries) == 2 + assert len(batcher._flush_jobs) == 0 + # ensure cells were not updated + assert (await _retrieve_cell_value(table, row_key)) == start_value + assert (await _retrieve_cell_value(table, row_key2)) == start_value + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.parametrize( + "start,increment,expected", + [ + (0, 0, 0), + (0, 1, 1), + (0, -1, -1), + (1, 0, 1), + (0, -100, -100), + (0, 3000, 3000), + (10, 4, 14), + (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), + (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), + (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), + ], +) +@pytest.mark.asyncio +async def test_read_modify_write_row_increment( + client, table, temp_rows, start, increment, expected +): + """ + test 
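The overflow rows in the parametrize list above follow from the server treating the counter as a signed 64-bit integer that wraps around. Assuming _MAX_INCREMENT_VALUE is 2**63 - 1, the signed 64-bit maximum, the expected values can be reproduced with plain Python:

MAX = 2**63 - 1  # assumed value of _MAX_INCREMENT_VALUE

def wrap_int64(value):
    # reduce an arbitrary integer into signed 64-bit two's-complement range
    return (value + 2**63) % 2**64 - 2**63

assert wrap_int64(MAX + 2) == -MAX   # (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE)
assert wrap_int64(-MAX - 2) == MAX   # (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE)
assert wrap_int64(MAX - MAX) == 0    # (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0)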
read_modify_write_row + """ + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + + rule = IncrementRule(family, qualifier, increment) + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert int(result[0]) == expected + # ensure that reading from server gives same value + assert (await _retrieve_cell_value(table, row_key)) == result[0].value + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.parametrize( + "start,append,expected", + [ + (b"", b"", b""), + ("", "", b""), + (b"abc", b"123", b"abc123"), + (b"abc", "123", b"abc123"), + ("", b"1", b"1"), + (b"abc", "", b"abc"), + (b"hello", b"world", b"helloworld"), + ], +) +@pytest.mark.asyncio +async def test_read_modify_write_row_append( + client, table, temp_rows, start, append, expected +): + """ + test read_modify_write_row + """ + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + + rule = AppendValueRule(family, qualifier, append) + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert result[0].value == expected + # ensure that reading from server gives same value + assert (await _retrieve_cell_value(table, row_key)) == result[0].value + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_read_modify_write_row_chained(client, table, temp_rows): + """ + test read_modify_write_row with multiple rules + """ + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + start_amount = 1 + increment_amount = 10 + await temp_rows.add_row( + row_key, value=start_amount, family=family, qualifier=qualifier + ) + rule = [ + IncrementRule(family, qualifier, increment_amount), + AppendValueRule(family, qualifier, "hello"), + AppendValueRule(family, qualifier, "world"), + AppendValueRule(family, qualifier, "!"), + ] + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert result[0].family == family + assert result[0].qualifier == qualifier + # result should be a bytes number string for the IncrementRules, followed by the AppendValueRule values + assert ( + result[0].value + == (start_amount + increment_amount).to_bytes(8, "big", signed=True) + + b"helloworld!" 
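The chained read_modify_write_row assertion is easier to read with the expected bytes written out: the IncrementRule leaves an 8-byte big-endian counter in the cell, and each AppendValueRule then concatenates its UTF-8 encoded value onto it:

start_amount, increment_amount = 1, 10
expected = (start_amount + increment_amount).to_bytes(8, "big", signed=True)
for piece in ("hello", "world", "!"):
    expected += piece.encode("utf-8")
assert expected == b"\x00\x00\x00\x00\x00\x00\x00\x0b" + b"helloworld!"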
+ ) + # ensure that reading from server gives same value + assert (await _retrieve_cell_value(table, row_key)) == result[0].value + + +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.parametrize( + "start_val,predicate_range,expected_result", + [ + (1, (0, 2), True), + (-1, (0, 2), False), + ], +) +@pytest.mark.asyncio +async def test_check_and_mutate( + client, table, temp_rows, start_val, predicate_range, expected_result +): + """ + test that check_and_mutate_row works applies the right mutations, and returns the right result + """ + from google.cloud.bigtable.data.mutations import SetCell + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + + await temp_rows.add_row( + row_key, value=start_val, family=family, qualifier=qualifier + ) + + false_mutation_value = b"false-mutation-value" + false_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value + ) + true_mutation_value = b"true-mutation-value" + true_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value + ) + predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) + result = await table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + assert result == expected_result + # ensure cell is updated + expected_value = true_mutation_value if expected_result else false_mutation_value + assert (await _retrieve_cell_value(table, row_key)) == expected_value + + +@pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", +) +@pytest.mark.usefixtures("client") +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_check_and_mutate_empty_request(client, table): + """ + check_and_mutate with no true or fale mutations should raise an error + """ + from google.api_core import exceptions + + with pytest.raises(exceptions.InvalidArgument) as e: + await table.check_and_mutate_row( + b"row_key", None, true_case_mutations=None, false_case_mutations=None + ) + assert "No mutations provided" in str(e.value) + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_read_rows_stream(table, temp_rows): + """ + Ensure that the read_rows_stream method works + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + + # full table scan + generator = await table.read_rows_stream({}) + first_row = await generator.__anext__() + second_row = await generator.__anext__() + assert first_row.row_key == b"row_key_1" + assert second_row.row_key == b"row_key_2" + with pytest.raises(StopAsyncIteration): + await generator.__anext__() + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_read_rows(table, temp_rows): + """ + Ensure that the read_rows method works + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + # full table scan + row_list = await table.read_rows({}) + assert len(row_list) == 2 + assert row_list[0].row_key == b"row_key_1" + assert row_list[1].row_key == b"row_key_2" + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) 
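test_read_rows_stream drives the stream by hand with __anext__ so it can assert on ordering; ordinary callers would more likely consume the same object with async for, since the awaited read_rows_stream call returns an async iterator of Row objects. A short sketch of that usage:

async def print_all_rows(table):
    stream = await table.read_rows_stream({})  # empty query = full table scan
    async for row in stream:
        print(row.row_key)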
+@pytest.mark.asyncio +async def test_read_rows_sharded_simple(table, temp_rows): + """ + Test read rows sharded with two queries + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) + row_list = await table.read_rows_sharded([query1, query2]) + assert len(row_list) == 4 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"b" + assert row_list[3].row_key == b"d" + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_read_rows_sharded_from_sample(table, temp_rows): + """ + Test end-to-end sharding + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.read_rows_query import RowRange + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + + table_shard_keys = await table.sample_row_keys() + query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) + shard_queries = query.shard(table_shard_keys) + row_list = await table.read_rows_sharded(shard_queries) + assert len(row_list) == 3 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"d" + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_read_rows_sharded_filters_limits(table, temp_rows): + """ + Test read rows sharded with filters and limits + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + + label_filter1 = ApplyLabelFilter("first") + label_filter2 = ApplyLabelFilter("second") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) + row_list = await table.read_rows_sharded([query1, query2]) + assert len(row_list) == 3 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"b" + assert row_list[2].row_key == b"d" + assert row_list[0][0].labels == ["first"] + assert row_list[1][0].labels == ["second"] + assert row_list[2][0].labels == ["second"] + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_read_rows_range_query(table, temp_rows): + """ + Ensure that the read_rows method works + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data import RowRange + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # full table scan + query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) + row_list = await table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, 
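test_read_rows_range_query only gets b"b" and b"c" back out of four rows because, in the query helpers used here, RowRange treats start_key as inclusive and end_key as exclusive by default. A minimal sketch of the query construction with that in mind:

from google.cloud.bigtable.data import ReadRowsQuery, RowRange

# start_key is inclusive and end_key is exclusive by default, so rows b"b"
# and b"c" match while b"d" does not
query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d"))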
maximum=5) +@pytest.mark.asyncio +async def test_read_rows_single_key_query(table, temp_rows): + """ + Ensure that the read_rows method works with specified query + """ + from google.cloud.bigtable.data import ReadRowsQuery + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # retrieve specific keys + query = ReadRowsQuery(row_keys=[b"a", b"c"]) + row_list = await table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.asyncio +async def test_read_rows_with_filter(table, temp_rows): + """ + ensure filters are applied + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # retrieve keys with filter + expected_label = "test-label" + row_filter = ApplyLabelFilter(expected_label) + query = ReadRowsQuery(row_filter=row_filter) + row_list = await table.read_rows(query) + assert len(row_list) == 4 + for row in row_list: + assert row[0].labels == [expected_label] + + +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_read_rows_stream_close(table, temp_rows): + """ + Ensure that the read_rows_stream can be closed + """ + from google.cloud.bigtable.data import ReadRowsQuery + + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + # full table scan + query = ReadRowsQuery() + generator = await table.read_rows_stream(query) + # grab first row + first_row = await generator.__anext__() + assert first_row.row_key == b"row_key_1" + # close stream early + await generator.aclose() + with pytest.raises(StopAsyncIteration): + await generator.__anext__() + + +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_read_row(table, temp_rows): + """ + Test read_row (single row helper) + """ + from google.cloud.bigtable.data import Row + + await temp_rows.add_row(b"row_key_1", value=b"value") + row = await table.read_row(b"row_key_1") + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + + +@pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", +) +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_read_row_missing(table): + """ + Test read_row when row does not exist + """ + from google.api_core import exceptions + + row_key = "row_key_not_exist" + result = await table.read_row(row_key) + assert result is None + with pytest.raises(exceptions.InvalidArgument) as e: + await table.read_row("") + assert "Row keys must be non-empty" in str(e) + + +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_read_row_w_filter(table, temp_rows): + """ + Test read_row (single row helper) + """ + from google.cloud.bigtable.data import Row + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"row_key_1", value=b"value") + expected_label = "test-label" + label_filter = ApplyLabelFilter(expected_label) + row = await table.read_row(b"row_key_1", row_filter=label_filter) + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + 
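Several of the filter tests rely on ApplyLabelFilter, which does not change cell values at all; it only tags each returned cell with the given label, which is why these tests use it to confirm a filter really reached the server. A small sketch of building such a query, with an arbitrary label string:

from google.cloud.bigtable.data import ReadRowsQuery
from google.cloud.bigtable.data.row_filters import ApplyLabelFilter

# every cell that comes back should carry this label in its .labels list
query = ReadRowsQuery(row_filter=ApplyLabelFilter("test-label"))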
assert row.cells[0].labels == [expected_label] + + +@pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", +) +@pytest.mark.usefixtures("table") +@pytest.mark.asyncio +async def test_row_exists(table, temp_rows): + from google.api_core import exceptions + + """Test row_exists with rows that exist and don't exist""" + assert await table.row_exists(b"row_key_1") is False + await temp_rows.add_row(b"row_key_1") + assert await table.row_exists(b"row_key_1") is True + assert await table.row_exists("row_key_1") is True + assert await table.row_exists(b"row_key_2") is False + assert await table.row_exists("row_key_2") is False + assert await table.row_exists("3") is False + await temp_rows.add_row(b"3") + assert await table.row_exists(b"3") is True + with pytest.raises(exceptions.InvalidArgument) as e: + await table.row_exists("") + assert "Row keys must be non-empty" in str(e) + + +@pytest.mark.usefixtures("table") +@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) +@pytest.mark.parametrize( + "cell_value,filter_input,expect_match", + [ + (b"abc", b"abc", True), + (b"abc", "abc", True), + (b".", ".", True), + (".*", ".*", True), + (".*", b".*", True), + ("a", ".*", False), + (b".*", b".*", True), + (r"\a", r"\a", True), + (b"\xe2\x98\x83", "☃", True), + ("☃", "☃", True), + (r"\C☃", r"\C☃", True), + (1, 1, True), + (2, 1, False), + (68, 68, True), + ("D", 68, False), + (68, "D", False), + (-1, -1, True), + (2852126720, 2852126720, True), + (-1431655766, -1431655766, True), + (-1431655766, -1, False), + ], +) +@pytest.mark.asyncio +async def test_literal_value_filter( + table, temp_rows, cell_value, filter_input, expect_match +): + """ + Literal value filter does complex escaping on re2 strings. + Make sure inputs are properly interpreted by the server + """ + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + from google.cloud.bigtable.data import ReadRowsQuery + + f = LiteralValueFilter(filter_input) + await temp_rows.add_row(b"row_key_1", value=cell_value) + query = ReadRowsQuery(row_filter=f) + row_list = await table.read_rows(query) + assert len(row_list) == bool( + expect_match + ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py deleted file mode 100644 index d12936305..000000000 --- a/tests/system/data/test_system_async.py +++ /dev/null @@ -1,992 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import asyncio -import uuid -import os -from google.api_core import retry -from google.api_core.exceptions import ClientError - -from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE -from google.cloud.environment_vars import BIGTABLE_EMULATOR - -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -from . 
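The integer rows in the literal value filter table, such as ("D", 68, False) and (68, "D", False), make sense once integer inputs are seen as 8-byte values: integer cell values, and apparently integer filter values as well, are encoded as big-endian signed 64-bit integers before comparison, so the int 68 never equals the single byte b"D". Assuming that encoding, the mismatch is easy to show:

def encode_int(value):
    # assumed encoding for integer cell/filter values: 8-byte big-endian signed
    return value.to_bytes(8, byteorder="big", signed=True)

assert encode_int(68) == b"\x00\x00\x00\x00\x00\x00\x00D"
assert encode_int(68) != b"D"            # why ("D", 68) and (68, "D") do not match
assert encode_int(68) == encode_int(68)  # why (68, 68) matches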
import TEST_FAMILY, TEST_FAMILY_2 - - -@CrossSync.export_sync( - path="tests.system.data.test_system.TempRowBuilder", - add_mapping_for_name="TempRowBuilder", -) -class TempRowBuilderAsync: - """ - Used to add rows to a table for testing purposes. - """ - - def __init__(self, table): - self.rows = [] - self.table = table - - @CrossSync.convert - async def add_row( - self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" - ): - if isinstance(value, str): - value = value.encode("utf-8") - elif isinstance(value, int): - value = value.to_bytes(8, byteorder="big", signed=True) - request = { - "table_name": self.table.table_name, - "row_key": row_key, - "mutations": [ - { - "set_cell": { - "family_name": family, - "column_qualifier": qualifier, - "value": value, - } - } - ], - } - await self.table.client._gapic_client.mutate_row(request) - self.rows.append(row_key) - - @CrossSync.convert - async def delete_rows(self): - if self.rows: - request = { - "table_name": self.table.table_name, - "entries": [ - {"row_key": row, "mutations": [{"delete_from_row": {}}]} - for row in self.rows - ], - } - await self.table.client._gapic_client.mutate_rows(request) - - -@CrossSync.export_sync(path="tests.system.data.test_system.TestSystem") -class TestSystemAsync: - @CrossSync.convert - @CrossSync.pytest_fixture(scope="session") - async def client(self): - project = os.getenv("GOOGLE_CLOUD_PROJECT") or None - async with CrossSync.DataClient(project=project, pool_size=4) as client: - yield client - - @CrossSync.convert - @CrossSync.pytest_fixture(scope="session") - async def table(self, client, table_id, instance_id): - async with client.get_table( - instance_id, - table_id, - ) as table: - yield table - - @pytest.fixture(scope="session") - def event_loop(self): - loop = asyncio.get_event_loop() - yield loop - loop.stop() - loop.close() - - @pytest.fixture(scope="session") - def column_family_config(self): - """ - specify column families to create when creating a new test table - """ - from google.cloud.bigtable_admin_v2 import types - - return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} - - @pytest.fixture(scope="session") - def init_table_id(self): - """ - The table_id to use when creating a new test table - """ - return f"test-table-{uuid.uuid4().hex}" - - @pytest.fixture(scope="session") - def cluster_config(self, project_id): - """ - Configuration for the clusters to use when creating a new instance - """ - from google.cloud.bigtable_admin_v2 import types - - cluster = { - "test-cluster": types.Cluster( - location=f"projects/{project_id}/locations/us-central1-b", - serve_nodes=1, - ) - } - return cluster - - @CrossSync.convert - @pytest.mark.usefixtures("table") - async def _retrieve_cell_value(self, table, row_key): - """ - Helper to read an individual row - """ - from google.cloud.bigtable.data import ReadRowsQuery - - row_list = await table.read_rows(ReadRowsQuery(row_keys=row_key)) - assert len(row_list) == 1 - row = row_list[0] - cell = row.cells[0] - return cell.value - - @CrossSync.convert - async def _create_row_and_mutation( - self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" - ): - """ - Helper to create a new row, and a sample set_cell mutation to change its value - """ - from google.cloud.bigtable.data.mutations import SetCell - - row_key = uuid.uuid4().hex.encode() - family = TEST_FAMILY - qualifier = b"test-qualifier" - await temp_rows.add_row( - row_key, family=family, qualifier=qualifier, value=start_value - ) - # ensure 
cell is initialized - assert (await self._retrieve_cell_value(table, row_key)) == start_value - - mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) - return row_key, mutation - - @CrossSync.convert - @CrossSync.pytest_fixture(scope="function") - async def temp_rows(self, table): - builder = CrossSync.TempRowBuilder(table) - yield builder - await builder.delete_rows() - - @pytest.mark.usefixtures("table") - @pytest.mark.usefixtures("client") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=10 - ) - @CrossSync.pytest - async def test_ping_and_warm_gapic(self, client, table): - """ - Simple ping rpc test - This test ensures channels are able to authenticate with backend - """ - request = {"name": table.instance_name} - await client._gapic_client.ping_and_warm(request) - - @pytest.mark.usefixtures("table") - @pytest.mark.usefixtures("client") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def test_ping_and_warm(self, client, table): - """ - Test ping and warm from handwritten client - """ - try: - channel = client.transport._grpc_channel.pool[0] - except Exception: - # for sync client - channel = client.transport._grpc_channel - results = await client._ping_and_warm_instances(channel) - assert len(results) == 1 - assert results[0] is None - - @CrossSync.pytest - @pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - async def test_mutation_set_cell(self, table, temp_rows): - """ - Ensure cells can be set properly - """ - row_key = b"bulk_mutate" - new_value = uuid.uuid4().hex.encode() - row_key, mutation = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - await table.mutate_row(row_key, mutation) - - # ensure cell is updated - assert (await self._retrieve_cell_value(table, row_key)) == new_value - - @pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" - ) - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def test_sample_row_keys(self, client, table, temp_rows, column_split_config): - """ - Sample keys should return a single sample in small test tables - """ - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - - results = await table.sample_row_keys() - assert len(results) == len(column_split_config) + 1 - # first keys should match the split config - for idx in range(len(column_split_config)): - assert results[idx][0] == column_split_config[idx] - assert isinstance(results[idx][1], int) - # last sample should be empty key - assert results[-1][0] == b"" - assert isinstance(results[-1][1], int) - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync.pytest - async def test_bulk_mutations_set_cell(self, client, table, temp_rows): - """ - Ensure cells can be set properly - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - row_key, mutation = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - - await table.bulk_mutate_rows([bulk_mutation]) - - # ensure cell is updated - assert (await self._retrieve_cell_value(table, row_key)) == new_value - - 
@CrossSync.pytest - async def test_bulk_mutations_raise_exception(self, client, table): - """ - If an invalid mutation is passed, an exception should be raised - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell - from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup - from google.cloud.bigtable.data.exceptions import FailedMutationEntryError - - row_key = uuid.uuid4().hex.encode() - mutation = SetCell( - family="nonexistent", qualifier=b"test-qualifier", new_value=b"" - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - - with pytest.raises(MutationsExceptionGroup) as exc: - await table.bulk_mutate_rows([bulk_mutation]) - assert len(exc.value.exceptions) == 1 - entry_error = exc.value.exceptions[0] - assert isinstance(entry_error, FailedMutationEntryError) - assert entry_error.index == 0 - assert entry_error.entry == bulk_mutation - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def test_mutations_batcher_context_manager(self, client, table, temp_rows): - """ - test batcher with context manager. Should flush on exit - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - row_key2, mutation2 = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - async with table.mutations_batcher() as batcher: - await batcher.append(bulk_mutation) - await batcher.append(bulk_mutation2) - # ensure cell is updated - assert (await self._retrieve_cell_value(table, row_key)) == new_value - assert len(batcher._staged_entries) == 0 - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def test_mutations_batcher_timer_flush(self, client, table, temp_rows): - """ - batch should occur after flush_interval seconds - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - row_key, mutation = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - flush_interval = 0.1 - async with table.mutations_batcher(flush_interval=flush_interval) as batcher: - await batcher.append(bulk_mutation) - await CrossSync.yield_to_event_loop() - assert len(batcher._staged_entries) == 1 - await CrossSync.sleep(flush_interval + 0.1) - assert len(batcher._staged_entries) == 0 - # ensure cell is updated - assert (await self._retrieve_cell_value(table, row_key)) == new_value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def test_mutations_batcher_count_flush(self, client, table, temp_rows): - """ - batch should flush after flush_limit_mutation_count mutations - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = await self._create_row_and_mutation( - table, 
temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - async with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: - await batcher.append(bulk_mutation) - assert len(batcher._flush_jobs) == 0 - # should be noop; flush not scheduled - assert len(batcher._staged_entries) == 1 - await batcher.append(bulk_mutation2) - # task should now be scheduled - assert len(batcher._flush_jobs) == 1 - # let flush complete - for future in list(batcher._flush_jobs): - await future - # for sync version: grab result - future.result() - assert len(batcher._staged_entries) == 0 - assert len(batcher._flush_jobs) == 0 - # ensure cells were updated - assert (await self._retrieve_cell_value(table, row_key)) == new_value - assert (await self._retrieve_cell_value(table, row_key2)) == new_value2 - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): - """ - batch should flush after flush_limit_bytes bytes - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = await self._create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 - - async with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: - await batcher.append(bulk_mutation) - assert len(batcher._flush_jobs) == 0 - assert len(batcher._staged_entries) == 1 - await batcher.append(bulk_mutation2) - # task should now be scheduled - assert len(batcher._flush_jobs) == 1 - assert len(batcher._staged_entries) == 0 - # let flush complete - for future in list(batcher._flush_jobs): - await future - # for sync version: grab result - future.result() - # ensure cells were updated - assert (await self._retrieve_cell_value(table, row_key)) == new_value - assert (await self._retrieve_cell_value(table, row_key2)) == new_value2 - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync.pytest - async def test_mutations_batcher_no_flush(self, client, table, temp_rows): - """ - test with no flush requirements met - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - start_value = b"unchanged" - row_key, mutation = await self._create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = await self._create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 - async with table.mutations_batcher( - flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 - ) as batcher: - await batcher.append(bulk_mutation) - assert len(batcher._staged_entries) == 1 - await 
batcher.append(bulk_mutation2) - # flush not scheduled - assert len(batcher._flush_jobs) == 0 - await CrossSync.yield_to_event_loop() - assert len(batcher._staged_entries) == 2 - assert len(batcher._flush_jobs) == 0 - # ensure cells were not updated - assert (await self._retrieve_cell_value(table, row_key)) == start_value - assert (await self._retrieve_cell_value(table, row_key2)) == start_value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @pytest.mark.parametrize( - "start,increment,expected", - [ - (0, 0, 0), - (0, 1, 1), - (0, -1, -1), - (1, 0, 1), - (0, -100, -100), - (0, 3000, 3000), - (10, 4, 14), - (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), - (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), - (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), - ], - ) - @CrossSync.pytest - async def test_read_modify_write_row_increment( - self, client, table, temp_rows, start, increment, expected - ): - """ - test read_modify_write_row - """ - from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - await temp_rows.add_row( - row_key, value=start, family=family, qualifier=qualifier - ) - - rule = IncrementRule(family, qualifier, increment) - result = await table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert len(result) == 1 - assert result[0].family == family - assert result[0].qualifier == qualifier - assert int(result[0]) == expected - # ensure that reading from server gives same value - assert (await self._retrieve_cell_value(table, row_key)) == result[0].value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @pytest.mark.parametrize( - "start,append,expected", - [ - (b"", b"", b""), - ("", "", b""), - (b"abc", b"123", b"abc123"), - (b"abc", "123", b"abc123"), - ("", b"1", b"1"), - (b"abc", "", b"abc"), - (b"hello", b"world", b"helloworld"), - ], - ) - @CrossSync.pytest - async def test_read_modify_write_row_append( - self, client, table, temp_rows, start, append, expected - ): - """ - test read_modify_write_row - """ - from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - await temp_rows.add_row( - row_key, value=start, family=family, qualifier=qualifier - ) - - rule = AppendValueRule(family, qualifier, append) - result = await table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert len(result) == 1 - assert result[0].family == family - assert result[0].qualifier == qualifier - assert result[0].value == expected - # ensure that reading from server gives same value - assert (await self._retrieve_cell_value(table, row_key)) == result[0].value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync.pytest - async def test_read_modify_write_row_chained(self, client, table, temp_rows): - """ - test read_modify_write_row with multiple rules - """ - from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule - from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - start_amount = 1 - increment_amount = 10 - await temp_rows.add_row( - row_key, value=start_amount, family=family, qualifier=qualifier - ) - rule = [ - IncrementRule(family, qualifier, increment_amount), - AppendValueRule(family, qualifier, 
"hello"), - AppendValueRule(family, qualifier, "world"), - AppendValueRule(family, qualifier, "!"), - ] - result = await table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert result[0].family == family - assert result[0].qualifier == qualifier - # result should be a bytes number string for the IncrementRules, followed by the AppendValueRule values - assert ( - result[0].value - == (start_amount + increment_amount).to_bytes(8, "big", signed=True) - + b"helloworld!" - ) - # ensure that reading from server gives same value - assert (await self._retrieve_cell_value(table, row_key)) == result[0].value - - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @pytest.mark.parametrize( - "start_val,predicate_range,expected_result", - [ - (1, (0, 2), True), - (-1, (0, 2), False), - ], - ) - @CrossSync.pytest - async def test_check_and_mutate( - self, client, table, temp_rows, start_val, predicate_range, expected_result - ): - """ - test that check_and_mutate_row works applies the right mutations, and returns the right result - """ - from google.cloud.bigtable.data.mutations import SetCell - from google.cloud.bigtable.data.row_filters import ValueRangeFilter - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - - await temp_rows.add_row( - row_key, value=start_val, family=family, qualifier=qualifier - ) - - false_mutation_value = b"false-mutation-value" - false_mutation = SetCell( - family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value - ) - true_mutation_value = b"true-mutation-value" - true_mutation = SetCell( - family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value - ) - predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) - result = await table.check_and_mutate_row( - row_key, - predicate, - true_case_mutations=true_mutation, - false_case_mutations=false_mutation, - ) - assert result == expected_result - # ensure cell is updated - expected_value = ( - true_mutation_value if expected_result else false_mutation_value - ) - assert (await self._retrieve_cell_value(table, row_key)) == expected_value - - @pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", - ) - @pytest.mark.usefixtures("client") - @pytest.mark.usefixtures("table") - @CrossSync.pytest - async def test_check_and_mutate_empty_request(self, client, table): - """ - check_and_mutate with no true or fale mutations should raise an error - """ - from google.api_core import exceptions - - with pytest.raises(exceptions.InvalidArgument) as e: - await table.check_and_mutate_row( - b"row_key", None, true_case_mutations=None, false_case_mutations=None - ) - assert "No mutations provided" in str(e.value) - - @pytest.mark.usefixtures("table") - @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def test_read_rows_stream(self, table, temp_rows): - """ - Ensure that the read_rows_stream method works - """ - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - - # full table scan - generator = await table.read_rows_stream({}) - first_row = await generator.__anext__() - second_row = await generator.__anext__() - assert first_row.row_key == b"row_key_1" - assert second_row.row_key == b"row_key_2" - with pytest.raises(CrossSync.StopIteration): - await generator.__anext__() - - 
@pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def test_read_rows(self, table, temp_rows): - """ - Ensure that the read_rows method works - """ - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - # full table scan - row_list = await table.read_rows({}) - assert len(row_list) == 2 - assert row_list[0].row_key == b"row_key_1" - assert row_list[1].row_key == b"row_key_2" - - @pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def test_read_rows_sharded_simple(self, table, temp_rows): - """ - Test read rows sharded with two queries - """ - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) - query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) - row_list = await table.read_rows_sharded([query1, query2]) - assert len(row_list) == 4 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"c" - assert row_list[2].row_key == b"b" - assert row_list[3].row_key == b"d" - - @pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def test_read_rows_sharded_from_sample(self, table, temp_rows): - """ - Test end-to-end sharding - """ - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - from google.cloud.bigtable.data.read_rows_query import RowRange - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - - table_shard_keys = await table.sample_row_keys() - query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) - shard_queries = query.shard(table_shard_keys) - row_list = await table.read_rows_sharded(shard_queries) - assert len(row_list) == 3 - assert row_list[0].row_key == b"b" - assert row_list[1].row_key == b"c" - assert row_list[2].row_key == b"d" - - @pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def test_read_rows_sharded_filters_limits(self, table, temp_rows): - """ - Test read rows sharded with filters and limits - """ - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - - label_filter1 = ApplyLabelFilter("first") - label_filter2 = ApplyLabelFilter("second") - query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) - query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) - row_list = await table.read_rows_sharded([query1, query2]) - assert len(row_list) == 3 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"b" - assert row_list[2].row_key == b"d" - assert row_list[0][0].labels == ["first"] - assert row_list[1][0].labels == ["second"] - assert row_list[2][0].labels == ["second"] - - @pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def 
test_read_rows_range_query(self, table, temp_rows): - """ - Ensure that the read_rows method works - """ - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable.data import RowRange - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - # full table scan - query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) - row_list = await table.read_rows(query) - assert len(row_list) == 2 - assert row_list[0].row_key == b"b" - assert row_list[1].row_key == b"c" - - @pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def test_read_rows_single_key_query(self, table, temp_rows): - """ - Ensure that the read_rows method works with specified query - """ - from google.cloud.bigtable.data import ReadRowsQuery - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - # retrieve specific keys - query = ReadRowsQuery(row_keys=[b"a", b"c"]) - row_list = await table.read_rows(query) - assert len(row_list) == 2 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"c" - - @pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @CrossSync.pytest - async def test_read_rows_with_filter(self, table, temp_rows): - """ - ensure filters are applied - """ - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - # retrieve keys with filter - expected_label = "test-label" - row_filter = ApplyLabelFilter(expected_label) - query = ReadRowsQuery(row_filter=row_filter) - row_list = await table.read_rows(query) - assert len(row_list) == 4 - for row in row_list: - assert row[0].labels == [expected_label] - - @pytest.mark.usefixtures("table") - @CrossSync.convert(replace_symbols={"__anext__": "__next__", "aclose": "close"}) - @CrossSync.pytest - async def test_read_rows_stream_close(self, table, temp_rows): - """ - Ensure that the read_rows_stream can be closed - """ - from google.cloud.bigtable.data import ReadRowsQuery - - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - # full table scan - query = ReadRowsQuery() - generator = await table.read_rows_stream(query) - # grab first row - first_row = await generator.__anext__() - assert first_row.row_key == b"row_key_1" - # close stream early - await generator.aclose() - with pytest.raises(CrossSync.StopIteration): - await generator.__anext__() - - @pytest.mark.usefixtures("table") - @CrossSync.pytest - async def test_read_row(self, table, temp_rows): - """ - Test read_row (single row helper) - """ - from google.cloud.bigtable.data import Row - - await temp_rows.add_row(b"row_key_1", value=b"value") - row = await table.read_row(b"row_key_1") - assert isinstance(row, Row) - assert row.row_key == b"row_key_1" - assert row.cells[0].value == b"value" - - @pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", - ) - @pytest.mark.usefixtures("table") - @CrossSync.pytest - async def test_read_row_missing(self, table): - """ - Test read_row when row does not exist - """ - from google.api_core import exceptions - - 
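# a missing row should return None rather than raising; only an empty row key is rejected below with InvalidArgument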
row_key = "row_key_not_exist" - result = await table.read_row(row_key) - assert result is None - with pytest.raises(exceptions.InvalidArgument) as e: - await table.read_row("") - assert "Row keys must be non-empty" in str(e) - - @pytest.mark.usefixtures("table") - @CrossSync.pytest - async def test_read_row_w_filter(self, table, temp_rows): - """ - Test read_row (single row helper) - """ - from google.cloud.bigtable.data import Row - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - await temp_rows.add_row(b"row_key_1", value=b"value") - expected_label = "test-label" - label_filter = ApplyLabelFilter(expected_label) - row = await table.read_row(b"row_key_1", row_filter=label_filter) - assert isinstance(row, Row) - assert row.row_key == b"row_key_1" - assert row.cells[0].value == b"value" - assert row.cells[0].labels == [expected_label] - - @pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", - ) - @pytest.mark.usefixtures("table") - @CrossSync.pytest - async def test_row_exists(self, table, temp_rows): - from google.api_core import exceptions - - """Test row_exists with rows that exist and don't exist""" - assert await table.row_exists(b"row_key_1") is False - await temp_rows.add_row(b"row_key_1") - assert await table.row_exists(b"row_key_1") is True - assert await table.row_exists("row_key_1") is True - assert await table.row_exists(b"row_key_2") is False - assert await table.row_exists("row_key_2") is False - assert await table.row_exists("3") is False - await temp_rows.add_row(b"3") - assert await table.row_exists(b"3") is True - with pytest.raises(exceptions.InvalidArgument) as e: - await table.row_exists("") - assert "Row keys must be non-empty" in str(e) - - @pytest.mark.usefixtures("table") - @CrossSync.Retry( - predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 - ) - @pytest.mark.parametrize( - "cell_value,filter_input,expect_match", - [ - (b"abc", b"abc", True), - (b"abc", "abc", True), - (b".", ".", True), - (".*", ".*", True), - (".*", b".*", True), - ("a", ".*", False), - (b".*", b".*", True), - (r"\a", r"\a", True), - (b"\xe2\x98\x83", "☃", True), - ("☃", "☃", True), - (r"\C☃", r"\C☃", True), - (1, 1, True), - (2, 1, False), - (68, 68, True), - ("D", 68, False), - (68, "D", False), - (-1, -1, True), - (2852126720, 2852126720, True), - (-1431655766, -1431655766, True), - (-1431655766, -1, False), - ], - ) - @CrossSync.pytest - async def test_literal_value_filter( - self, table, temp_rows, cell_value, filter_input, expect_match - ): - """ - Literal value filter does complex escaping on re2 strings. 
- Make sure inputs are properly interpreted by the server - """ - from google.cloud.bigtable.data.row_filters import LiteralValueFilter - from google.cloud.bigtable.data import ReadRowsQuery - - f = LiteralValueFilter(filter_input) - await temp_rows.add_row(b"row_key_1", value=cell_value) - query = ReadRowsQuery(row_filter=f) - row_list = await table.read_rows(query) - assert len(row_list) == bool( - expect_match - ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" diff --git a/tests/unit/data/_async/__init__.py b/tests/unit/data/_async/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index a307a7008..e03028c45 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -16,42 +16,42 @@ from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 -from google.api_core.exceptions import DeadlineExceeded -from google.api_core.exceptions import Forbidden - -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +import google.api_core.exceptions as core_exceptions # try/except added for compatibility with python < 3.8 try: from unittest import mock + from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER import mock # type: ignore + from mock import AsyncMock # type: ignore + + +def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation -@CrossSync.export_sync( - path="tests.unit.data._sync.test__mutate_rows.TestMutateRowsOperation", -) class TestMutateRowsOperation: def _target_class(self): - return CrossSync._MutateRowsOperation + from google.cloud.bigtable.data._async._mutate_rows import ( + _MutateRowsOperationAsync, + ) + + return _MutateRowsOperationAsync def _make_one(self, *args, **kwargs): if not args: kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) - kwargs["table"] = kwargs.pop("table", CrossSync.Mock()) + kwargs["table"] = kwargs.pop("table", AsyncMock()) kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) return self._target_class()(*args, **kwargs) - def _make_mutation(self, count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation - - @CrossSync.convert async def _mock_stream(self, mutation_list, error_dict): for idx, entry in enumerate(mutation_list): code = error_dict.get(idx, 0) @@ -64,7 +64,7 @@ async def _mock_stream(self, mutation_list, error_dict): ) def _make_mock_gapic(self, mutation_list, error_dict=None): - mock_fn = CrossSync.Mock() + mock_fn = AsyncMock() if error_dict is None: error_dict = {} mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( @@ -83,7 +83,7 @@ def test_ctor(self): client = mock.Mock() table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] + entries = [_make_mutation(), _make_mutation()] operation_timeout = 0.05 attempt_timeout = 0.01 retryable_exceptions = () @@ -136,14 +136,17 @@ def test_ctor_too_many_entries(self): client = mock.Mock() table = mock.Mock() - entries = [self._make_mutation()] * 
(_MUTATE_ROWS_REQUEST_MUTATION_LIMIT + 1) + entries = [_make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT operation_timeout = 0.05 attempt_timeout = 0.01 + # no errors if at limit + self._make_one(client, table, entries, operation_timeout, attempt_timeout) + # raise error after crossing with pytest.raises(ValueError) as e: self._make_one( client, table, - entries, + entries + [_make_mutation()], operation_timeout, attempt_timeout, ) @@ -152,18 +155,18 @@ def test_ctor_too_many_entries(self): ) assert "Found 100001" in str(e.value) - @CrossSync.pytest + @pytest.mark.asyncio async def test_mutate_rows_operation(self): """ Test successful case of mutate_rows_operation """ client = mock.Mock() table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] + entries = [_make_mutation(), _make_mutation()] operation_timeout = 0.05 cls = self._target_class() with mock.patch( - f"{cls.__module__}.{cls.__name__}._run_attempt", CrossSync.Mock() + f"{cls.__module__}.{cls.__name__}._run_attempt", AsyncMock() ) as attempt_mock: instance = self._make_one( client, table, entries, operation_timeout, operation_timeout @@ -171,15 +174,17 @@ async def test_mutate_rows_operation(self): await instance.start() assert attempt_mock.call_count == 1 - @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) - @CrossSync.pytest + @pytest.mark.parametrize( + "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] + ) + @pytest.mark.asyncio async def test_mutate_rows_attempt_exception(self, exc_type): """ exceptions raised from attempt should be raised in MutationsExceptionGroup """ - client = CrossSync.Mock() + client = AsyncMock() table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] + entries = [_make_mutation(), _make_mutation()] operation_timeout = 0.05 expected_exception = exc_type("test") client.mutate_rows.side_effect = expected_exception @@ -197,8 +202,10 @@ async def test_mutate_rows_attempt_exception(self, exc_type): assert len(instance.errors) == 2 assert len(instance.remaining_indices) == 0 - @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) - @CrossSync.pytest + @pytest.mark.parametrize( + "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] + ) + @pytest.mark.asyncio async def test_mutate_rows_exception(self, exc_type): """ exceptions raised from retryable should be raised in MutationsExceptionGroup @@ -208,13 +215,13 @@ async def test_mutate_rows_exception(self, exc_type): client = mock.Mock() table = mock.Mock() - entries = [self._make_mutation(), self._make_mutation()] + entries = [_make_mutation(), _make_mutation()] operation_timeout = 0.05 expected_cause = exc_type("abort") with mock.patch.object( self._target_class(), "_run_attempt", - CrossSync.Mock(), + AsyncMock(), ) as attempt_mock: attempt_mock.side_effect = expected_cause found_exc = None @@ -234,24 +241,27 @@ async def test_mutate_rows_exception(self, exc_type): @pytest.mark.parametrize( "exc_type", - [DeadlineExceeded, RuntimeError], + [core_exceptions.DeadlineExceeded, RuntimeError], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): """ If an exception fails but eventually passes, it should not raise an exception """ + from google.cloud.bigtable.data._async._mutate_rows import ( + _MutateRowsOperationAsync, + ) client = mock.Mock() table = mock.Mock() - entries = [self._make_mutation()] + entries = [_make_mutation()] 
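# the attempt mock below fails num_retries times and then returns None, so start() is expected to finish without raising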
operation_timeout = 1 expected_cause = exc_type("retry") num_retries = 2 with mock.patch.object( - self._target_class(), + _MutateRowsOperationAsync, "_run_attempt", - CrossSync.Mock(), + AsyncMock(), ) as attempt_mock: attempt_mock.side_effect = [expected_cause] * num_retries + [None] instance = self._make_one( @@ -265,7 +275,7 @@ async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): await instance.start() assert attempt_mock.call_count == num_retries + 1 - @CrossSync.pytest + @pytest.mark.asyncio async def test_mutate_rows_incomplete_ignored(self): """ MutateRowsIncomplete exceptions should not be added to error list @@ -276,12 +286,12 @@ async def test_mutate_rows_incomplete_ignored(self): client = mock.Mock() table = mock.Mock() - entries = [self._make_mutation()] + entries = [_make_mutation()] operation_timeout = 0.05 with mock.patch.object( self._target_class(), "_run_attempt", - CrossSync.Mock(), + AsyncMock(), ) as attempt_mock: attempt_mock.side_effect = _MutateRowsIncomplete("ignored") found_exc = None @@ -296,10 +306,10 @@ async def test_mutate_rows_incomplete_ignored(self): assert len(found_exc.exceptions) == 1 assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) - @CrossSync.pytest + @pytest.mark.asyncio async def test_run_attempt_single_entry_success(self): """Test mutating a single entry""" - mutation = self._make_mutation() + mutation = _make_mutation() expected_timeout = 1.3 mock_gapic_fn = self._make_mock_gapic({0: mutation}) instance = self._make_one( @@ -314,7 +324,7 @@ async def test_run_attempt_single_entry_success(self): assert kwargs["timeout"] == expected_timeout assert kwargs["entries"] == [mutation._to_pb()] - @CrossSync.pytest + @pytest.mark.asyncio async def test_run_attempt_empty_request(self): """Calling with no mutations should result in no API calls""" mock_gapic_fn = self._make_mock_gapic([]) @@ -324,14 +334,14 @@ async def test_run_attempt_empty_request(self): await instance._run_attempt() assert mock_gapic_fn.call_count == 0 - @CrossSync.pytest + @pytest.mark.asyncio async def test_run_attempt_partial_success_retryable(self): """Some entries succeed, but one fails. Should report the proper index, and raise incomplete exception""" from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - success_mutation = self._make_mutation() - success_mutation_2 = self._make_mutation() - failure_mutation = self._make_mutation() + success_mutation = _make_mutation() + success_mutation_2 = _make_mutation() + failure_mutation = _make_mutation() mutations = [success_mutation, failure_mutation, success_mutation_2] mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) instance = self._make_one( @@ -347,12 +357,12 @@ async def test_run_attempt_partial_success_retryable(self): assert instance.errors[1][0].grpc_status_code == 300 assert 2 not in instance.errors - @CrossSync.pytest + @pytest.mark.asyncio async def test_run_attempt_partial_success_non_retryable(self): """Some entries succeed, but one fails. Exception marked as non-retryable. 
Do not raise incomplete error""" - success_mutation = self._make_mutation() - success_mutation_2 = self._make_mutation() - failure_mutation = self._make_mutation() + success_mutation = _make_mutation() + success_mutation_2 = _make_mutation() + failure_mutation = _make_mutation() mutations = [success_mutation, failure_mutation, success_mutation_2] mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) instance = self._make_one( diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 896c17879..2bf8688fd 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -13,19 +13,23 @@ import pytest -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync # try/except added for compatibility with python < 3.8 try: from unittest import mock + from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER import mock # type: ignore + from mock import AsyncMock # type: ignore # noqa F401 +TEST_FAMILY = "family_name" +TEST_QUALIFIER = b"qualifier" +TEST_TIMESTAMP = 123456789 +TEST_LABELS = ["label1", "label2"] -@CrossSync.export_sync( - path="tests.unit.data._sync.test__read_rows.TestReadRowsOperation", -) -class TestReadRowsOperationAsync: + +class TestReadRowsOperation: """ Tests helper functions in the ReadRowsOperation class in-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt @@ -33,9 +37,10 @@ class TestReadRowsOperationAsync: """ @staticmethod - @CrossSync.convert def _get_target_class(): - return CrossSync._ReadRowsOperation + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + + return _ReadRowsOperationAsync def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) @@ -55,9 +60,8 @@ def test_ctor(self): expected_operation_timeout = 42 expected_request_timeout = 44 time_gen_mock = mock.Mock() - subpath = "_async" if CrossSync.is_async else "_sync" with mock.patch( - f"google.cloud.bigtable.data.{subpath}._read_rows._attempt_timeout_generator", + "google.cloud.bigtable.data._async._read_rows._attempt_timeout_generator", time_gen_mock, ): instance = self._make_one( @@ -238,7 +242,7 @@ def test_revise_to_empty_rowset(self): (4, 2, 2), ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_revise_limit(self, start_limit, emit_num, expected_limit): """ revise_limit should revise the request's limit field @@ -279,7 +283,7 @@ async def mock_stream(): assert instance._remaining_count == expected_limit @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) - @CrossSync.pytest + @pytest.mark.asyncio async def test_revise_limit_over_limit(self, start_limit, emit_num): """ Should raise runtime error if we get in state where emit_num > start_num @@ -318,11 +322,7 @@ async def mock_stream(): pass assert "emit count exceeds row limit" in str(e.value) - @CrossSync.pytest - @CrossSync.convert( - sync_name="test_close", - replace_symbols={"aclose": "close", "__anext__": "__next__"}, - ) + @pytest.mark.asyncio async def test_aclose(self): """ should be able to close a stream safely with aclose. 
@@ -334,7 +334,7 @@ async def mock_stream(): yield 1 with mock.patch.object( - self._get_target_class(), "_read_rows_attempt" + _ReadRowsOperationAsync, "_read_rows_attempt" ) as mock_attempt: instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) wrapped_gen = mock_stream() @@ -343,20 +343,20 @@ async def mock_stream(): # read one row await gen.__anext__() await gen.aclose() - with pytest.raises(CrossSync.StopIteration): + with pytest.raises(StopAsyncIteration): await gen.__anext__() # try calling a second time await gen.aclose() # ensure close was propagated to wrapped generator - with pytest.raises(CrossSync.StopIteration): + with pytest.raises(StopAsyncIteration): await wrapped_gen.__anext__() - @CrossSync.pytest - @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) + @pytest.mark.asyncio async def test_retryable_ignore_repeated_rows(self): """ Duplicate rows should cause an invalid chunk error """ + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable_v2.types import ReadRowsResponse @@ -381,10 +381,37 @@ async def mock_stream(): instance = mock.Mock() instance._last_yielded_row_key = None instance._remaining_count = None - stream = self._get_target_class().chunk_stream( - instance, mock_awaitable_stream() - ) + stream = _ReadRowsOperationAsync.chunk_stream(instance, mock_awaitable_stream()) await stream.__anext__() with pytest.raises(InvalidChunk) as exc: await stream.__anext__() assert "row keys should be strictly increasing" in str(exc.value) + + +class MockStream(_ReadRowsOperationAsync): + """ + Mock a _ReadRowsOperationAsync stream for testing + """ + + def __init__(self, items=None, errors=None, operation_timeout=None): + self.transient_errors = errors + self.operation_timeout = operation_timeout + self.next_idx = 0 + if items is None: + items = list(range(10)) + self.items = items + + def __aiter__(self): + return self + + async def __anext__(self): + if self.next_idx >= len(self.items): + raise StopAsyncIteration + item = self.items[self.next_idx] + self.next_idx += 1 + if isinstance(item, Exception): + raise item + return item + + async def aclose(self): + pass diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index b51987c5d..9ebc403ce 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -32,62 +32,57 @@ from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - # try/except added for compatibility with python < 3.8 try: from unittest import mock + from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER import mock # type: ignore + from mock import AsyncMock # type: ignore -if CrossSync.is_async: - from google.api_core import grpc_helpers_async - from google.cloud.bigtable.data._async.client import TableAsync +VENEER_HEADER_REGEX = re.compile( + r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-data-async gl-python\/[0-9]+\.[\w.-]+ grpc\/[0-9]+\.[\w.-]+" +) - CrossSync.add_mapping("grpc_helpers", grpc_helpers_async) + +def _make_client(*args, use_emulator=True, **kwargs): + import os + from google.cloud.bigtable.data._async.client import BigtableDataClientAsync + + env_mask = {} + # by default, use emulator mode to avoid auth issues in CI + # 
emulator mode must be disabled by tests that check channel pooling/refresh background tasks + if use_emulator: + env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" + else: + # set some default values + kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) + kwargs["project"] = kwargs.get("project", "project-id") + with mock.patch.dict(os.environ, env_mask): + return BigtableDataClientAsync(*args, **kwargs) -@CrossSync.export_sync( - path="tests.unit.data._sync.test_client.TestBigtableDataClient", - add_mapping_for_name="TestBigtableDataClient", -) class TestBigtableDataClientAsync: - @staticmethod - @CrossSync.convert - def _get_target_class(): - return CrossSync.DataClient - - @classmethod - def _make_client(cls, *args, use_emulator=True, **kwargs): - import os - - env_mask = {} - # by default, use emulator mode to avoid auth issues in CI - # emulator mode must be disabled by tests that check channel pooling/refresh background tasks - if use_emulator: - env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" - import warnings - - warnings.filterwarnings("ignore", category=RuntimeWarning) - else: - # set some default values - kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) - kwargs["project"] = kwargs.get("project", "project-id") - with mock.patch.dict(os.environ, env_mask): - return cls._get_target_class()(*args, **kwargs) - - @CrossSync.pytest + def _get_target_class(self): + from google.cloud.bigtable.data._async.client import BigtableDataClientAsync + + return BigtableDataClientAsync + + def _make_one(self, *args, **kwargs): + return _make_client(*args, **kwargs) + + @pytest.mark.asyncio async def test_ctor(self): expected_project = "project-id" expected_pool_size = 11 expected_credentials = AnonymousCredentials() - client = self._make_client( + client = self._make_one( project="project-id", pool_size=expected_pool_size, credentials=expected_credentials, use_emulator=False, ) - await CrossSync.yield_to_event_loop() + await asyncio.sleep(0) assert client.project == expected_project assert len(client.transport._grpc_channel._pool) == expected_pool_size assert not client._active_instances @@ -95,29 +90,28 @@ async def test_ctor(self): assert client.transport._credentials == expected_credentials await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test_ctor_super_inits(self): + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) from google.cloud.client import ClientWithProject from google.api_core import client_options as client_options_lib - from google.cloud.bigtable import __version__ as bigtable_version project = "project-id" pool_size = 11 credentials = AnonymousCredentials() client_options = {"api_endpoint": "foo.bar:1234"} options_parsed = client_options_lib.from_dict(client_options) - asyncio_portion = "-async" if CrossSync.is_async else "" - transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" - with mock.patch.object( - CrossSync.GapicClient, "__init__" - ) as bigtable_client_init: + transport_str = f"pooled_grpc_asyncio_{pool_size}" + with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: bigtable_client_init.return_value = None with mock.patch.object( ClientWithProject, "__init__" ) as client_project_init: client_project_init.return_value = None try: - self._make_client( + self._make_one( project=project, pool_size=pool_size, credentials=credentials, @@ -139,16 +133,17 @@ async def test_ctor_super_inits(self): assert 
kwargs["credentials"] == credentials assert kwargs["client_options"] == options_parsed - @CrossSync.pytest + @pytest.mark.asyncio async def test_ctor_dict_options(self): + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) from google.api_core.client_options import ClientOptions client_options = {"api_endpoint": "foo.bar:1234"} - with mock.patch.object( - CrossSync.GapicClient, "__init__" - ) as bigtable_client_init: + with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: try: - self._make_client(client_options=client_options) + self._make_one(client_options=client_options) except TypeError: pass bigtable_client_init.assert_called_once() @@ -159,29 +154,17 @@ async def test_ctor_dict_options(self): with mock.patch.object( self._get_target_class(), "_start_background_channel_refresh" ) as start_background_refresh: - client = self._make_client( - client_options=client_options, use_emulator=False - ) + client = self._make_one(client_options=client_options, use_emulator=False) start_background_refresh.assert_called_once() await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test_veneer_grpc_headers(self): - client_component = "data-async" if CrossSync.is_async else "data" - VENEER_HEADER_REGEX = re.compile( - r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-" - + client_component - + r" gl-python\/[0-9]+\.[\w.-]+ grpc\/[0-9]+\.[\w.-]+" - ) - # client_info should be populated with headers to # detect as a veneer client - if CrossSync.is_async: - patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") - else: - patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") + patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") with patch as gapic_mock: - client = self._make_client(project="project-id") + client = self._make_one(project="project-id") wrapped_call_list = gapic_mock.call_args_list assert len(wrapped_call_list) > 0 # each wrapped call should have veneer headers @@ -196,27 +179,33 @@ async def test_veneer_grpc_headers(self): ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test_channel_pool_creation(self): pool_size = 14 - with mock.patch.object( - CrossSync.grpc_helpers, "create_channel", CrossSync.Mock() + with mock.patch( + "google.api_core.grpc_helpers_async.create_channel" ) as create_channel: - client = self._make_client(project="project-id", pool_size=pool_size) + create_channel.return_value = AsyncMock() + client = self._make_one(project="project-id", pool_size=pool_size) assert create_channel.call_count == pool_size await client.close() # channels should be unique - client = self._make_client(project="project-id", pool_size=pool_size) + client = self._make_one(project="project-id", pool_size=pool_size) pool_list = list(client.transport._grpc_channel._pool) pool_set = set(client.transport._grpc_channel._pool) assert len(pool_list) == len(pool_set) await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test_channel_pool_rotation(self): + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel, + ) + pool_size = 7 - with mock.patch.object(CrossSync.PooledChannel, "next_channel") as next_channel: - client = self._make_client(project="project-id", pool_size=pool_size) + + with mock.patch.object(PooledChannel, "next_channel") as next_channel: + client = 
self._make_one(project="project-id", pool_size=pool_size) assert len(client.transport._grpc_channel._pool) == pool_size next_channel.reset_mock() with mock.patch.object( @@ -235,30 +224,25 @@ async def test_channel_pool_rotation(self): unary_unary.reset_mock() await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test_channel_pool_replace(self): - import time - - sleep_module = asyncio if CrossSync.is_async else time - with mock.patch.object(sleep_module, "sleep"): + with mock.patch.object(asyncio, "sleep"): pool_size = 7 - client = self._make_client(project="project-id", pool_size=pool_size) + client = self._make_one(project="project-id", pool_size=pool_size) for replace_idx in range(pool_size): start_pool = [ channel for channel in client.transport._grpc_channel._pool ] grace_period = 9 with mock.patch.object( - type(client.transport._grpc_channel._pool[-1]), "close" + type(client.transport._grpc_channel._pool[0]), "close" ) as close: - new_channel = client.transport.create_channel() + new_channel = grpc.aio.insecure_channel("localhost:8080") await client.transport.replace_channel( replace_idx, grace=grace_period, new_channel=new_channel ) - close.assert_called_once() - if CrossSync.is_async: - close.assert_called_once_with(grace=grace_period) - close.assert_awaited_once() + close.assert_called_once_with(grace=grace_period) + close.assert_awaited_once() assert client.transport._grpc_channel._pool[replace_idx] == new_channel for i in range(pool_size): if i != replace_idx: @@ -267,59 +251,50 @@ async def test_channel_pool_replace(self): assert client.transport._grpc_channel._pool[i] != start_pool[i] await client.close() - @CrossSync.drop_method @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test__start_background_channel_refresh_sync(self): # should raise RuntimeError if called in a sync context - client = self._make_client(project="project-id", use_emulator=False) + client = self._make_one(project="project-id", use_emulator=False) with pytest.raises(RuntimeError): client._start_background_channel_refresh() - @CrossSync.pytest + @pytest.mark.asyncio async def test__start_background_channel_refresh_tasks_exist(self): # if tasks exist, should do nothing - client = self._make_client(project="project-id", use_emulator=False) + client = self._make_one(project="project-id", use_emulator=False) assert len(client._channel_refresh_tasks) > 0 with mock.patch.object(asyncio, "create_task") as create_task: client._start_background_channel_refresh() create_task.assert_not_called() await client.close() - @CrossSync.pytest + @pytest.mark.asyncio @pytest.mark.parametrize("pool_size", [1, 3, 7]) async def test__start_background_channel_refresh(self, pool_size): - import concurrent.futures - # should create background tasks for each channel - with mock.patch.object( - self._get_target_class(), "_ping_and_warm_instances", CrossSync.Mock() - ) as ping_and_warm: - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - client._start_background_channel_refresh() - assert len(client._channel_refresh_tasks) == pool_size - for task in client._channel_refresh_tasks: - if CrossSync.is_async: - assert isinstance(task, asyncio.Task) - else: - assert isinstance(task, concurrent.futures.Future) - if CrossSync.is_async: - await asyncio.sleep(0.1) - assert ping_and_warm.call_count == pool_size - for channel in client.transport._grpc_channel._pool: - ping_and_warm.assert_any_call(channel) + client = self._make_one( + project="project-id", 
pool_size=pool_size, use_emulator=False + ) + ping_and_warm = AsyncMock() + client._ping_and_warm_instances = ping_and_warm + client._start_background_channel_refresh() + assert len(client._channel_refresh_tasks) == pool_size + for task in client._channel_refresh_tasks: + assert isinstance(task, asyncio.Task) + await asyncio.sleep(0.1) + assert ping_and_warm.call_count == pool_size + for channel in client.transport._grpc_channel._pool: + ping_and_warm.assert_any_call(channel) await client.close() - @CrossSync.drop_method - @CrossSync.pytest + @pytest.mark.asyncio @pytest.mark.skipif( sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" ) async def test__start_background_channel_refresh_tasks_names(self): # if tasks exist, should do nothing pool_size = 3 - client = self._make_client( + client = self._make_one( project="project-id", pool_size=pool_size, use_emulator=False ) for i in range(pool_size): @@ -328,22 +303,15 @@ async def test__start_background_channel_refresh_tasks_names(self): assert "BigtableDataClientAsync channel refresh " in name await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test__ping_and_warm_instances(self): """ test ping and warm with mocked asyncio.gather """ client_mock = mock.Mock() - client_mock._execute_ping_and_warms = ( - lambda *args: self._get_target_class()._execute_ping_and_warms( - client_mock, *args - ) - ) - with mock.patch.object( - CrossSync, "gather_partials", CrossSync.Mock() - ) as gather: - # gather_partials is expected to call the function passed, and return the result - gather.side_effect = lambda partials, **kwargs: [None for _ in partials] + with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: + # simulate gather by returning the same number of items as passed in + gather.side_effect = lambda *args, **kwargs: [None for _ in args] channel = mock.Mock() # test with no instances client_mock._active_instances = [] @@ -351,8 +319,10 @@ async def test__ping_and_warm_instances(self): client_mock, channel ) assert len(result) == 0 - assert gather.call_args.kwargs["return_exceptions"] is True - assert gather.call_args.kwargs["sync_executor"] == client_mock._executor + gather.assert_called_once() + gather.assert_awaited_once() + assert not gather.call_args.args + assert gather.call_args.kwargs == {"return_exceptions": True} # test with instances client_mock._active_instances = [ (mock.Mock(), mock.Mock(), mock.Mock()) @@ -364,11 +334,8 @@ async def test__ping_and_warm_instances(self): ) assert len(result) == 4 gather.assert_called_once() - # expect one partial for each instance - partial_list = gather.call_args.args[0] - assert len(partial_list) == 4 - if CrossSync.is_async: - gather.assert_awaited_once() + gather.assert_awaited_once() + assert len(gather.call_args.args) == 4 # check grpc call arguments grpc_call_args = channel.unary_unary().call_args_list for idx, (_, kwargs) in enumerate(grpc_call_args): @@ -388,21 +355,15 @@ async def test__ping_and_warm_instances(self): == f"name={expected_instance}&app_profile_id={expected_app_profile}" ) - @CrossSync.pytest + @pytest.mark.asyncio async def test__ping_and_warm_single_instance(self): """ should be able to call ping and warm with single instance """ client_mock = mock.Mock() - client_mock._execute_ping_and_warms = ( - lambda *args: self._get_target_class()._execute_ping_and_warms( - client_mock, *args - ) - ) - with mock.patch.object( - CrossSync, "gather_partials", CrossSync.Mock() - ) as gather: - gather.side_effect = lambda *args, **kwargs: 
[fn() for fn in args[0]] + with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: + # simulate gather by returning the same number of items as passed in + gather.side_effect = lambda *args, **kwargs: [None for _ in args] channel = mock.Mock() # test with large set of instances client_mock._active_instances = [mock.Mock()] * 100 @@ -426,7 +387,7 @@ async def test__ping_and_warm_single_instance(self): metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" ) - @CrossSync.pytest + @pytest.mark.asyncio @pytest.mark.parametrize( "refresh_interval, wait_time, expected_sleep", [ @@ -444,46 +405,41 @@ async def test__manage_channel_first_sleep( # first sleep time should be `refresh_interval` seconds after client init import time - with mock.patch.object(time, "monotonic") as monotonic: - monotonic.return_value = 0 - with mock.patch.object(CrossSync, "event_wait") as sleep: + with mock.patch.object(time, "monotonic") as time: + time.return_value = 0 + with mock.patch.object(asyncio, "sleep") as sleep: sleep.side_effect = asyncio.CancelledError try: - client = self._make_client(project="project-id") + client = self._make_one(project="project-id") client._channel_init_time = -wait_time await client._manage_channel(0, refresh_interval, refresh_interval) except asyncio.CancelledError: pass sleep.assert_called_once() - call_time = sleep.call_args[0][1] + call_time = sleep.call_args[0][0] assert ( abs(call_time - expected_sleep) < 0.1 ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test__manage_channel_ping_and_warm(self): """ _manage channel should call ping and warm internally """ import time - import threading client_mock = mock.Mock() - client_mock._is_closed.is_set.return_value = False client_mock._channel_init_time = time.monotonic() channel_list = [mock.Mock(), mock.Mock()] client_mock.transport.channels = channel_list new_channel = mock.Mock() client_mock.transport.grpc_channel._create_channel.return_value = new_channel # should ping an warm all new channels, and old channels if sleeping - sleep_tuple = ( - (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple): + with mock.patch.object(asyncio, "sleep"): # stop process after replace_channel is called client_mock.transport.replace_channel.side_effect = asyncio.CancelledError - ping_and_warm = client_mock._ping_and_warm_instances = CrossSync.Mock() + ping_and_warm = client_mock._ping_and_warm_instances = AsyncMock() # should ping and warm old channel then new if sleep > 0 try: channel_idx = 1 @@ -510,7 +466,7 @@ async def test__manage_channel_ping_and_warm(self): pass ping_and_warm.assert_called_once_with(new_channel) - @CrossSync.pytest + @pytest.mark.asyncio @pytest.mark.parametrize( "refresh_interval, num_cycles, expected_sleep", [ @@ -525,59 +481,43 @@ async def test__manage_channel_sleeps( # make sure that sleeps work as expected import time import random - import threading channel_idx = 1 with mock.patch.object(random, "uniform") as uniform: uniform.side_effect = lambda min_, max_: min_ - with mock.patch.object(time, "time") as time_mock: - time_mock.return_value = 0 - sleep_tuple = ( - (asyncio, "sleep") - if CrossSync.is_async - else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: + with mock.patch.object(time, "time") as time: + time.return_value = 0 + with mock.patch.object(asyncio, "sleep") as 
sleep: sleep.side_effect = [None for i in range(num_cycles - 1)] + [ asyncio.CancelledError ] - client = self._make_client(project="project-id") - with mock.patch.object(client.transport, "replace_channel"): - try: - if refresh_interval is not None: - await client._manage_channel( - channel_idx, refresh_interval, refresh_interval - ) - else: - await client._manage_channel(channel_idx) - except asyncio.CancelledError: - pass + try: + client = self._make_one(project="project-id") + if refresh_interval is not None: + await client._manage_channel( + channel_idx, refresh_interval, refresh_interval + ) + else: + await client._manage_channel(channel_idx) + except asyncio.CancelledError: + pass assert sleep.call_count == num_cycles - if CrossSync.is_async: - total_sleep = sum([call[0][0] for call in sleep.call_args_list]) - else: - total_sleep = sum( - [call[1]["timeout"] for call in sleep.call_args_list] - ) + total_sleep = sum([call[0][0] for call in sleep.call_args_list]) assert ( abs(total_sleep - expected_sleep) < 0.1 ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test__manage_channel_random(self): import random - import threading - sleep_tuple = ( - (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: + with mock.patch.object(asyncio, "sleep") as sleep: with mock.patch.object(random, "uniform") as uniform: uniform.return_value = 0 try: uniform.side_effect = asyncio.CancelledError - client = self._make_client(project="project-id", pool_size=1) + client = self._make_one(project="project-id", pool_size=1) except asyncio.CancelledError: uniform.side_effect = None uniform.reset_mock() @@ -587,48 +527,41 @@ async def test__manage_channel_random(self): uniform.side_effect = lambda min_, max_: min_ sleep.side_effect = [None, None, asyncio.CancelledError] try: - with mock.patch.object(client.transport, "replace_channel"): - await client._manage_channel(0, min_val, max_val) + await client._manage_channel(0, min_val, max_val) except asyncio.CancelledError: pass - assert uniform.call_count == 3 + assert uniform.call_count == 2 uniform_args = [call[0] for call in uniform.call_args_list] for found_min, found_max in uniform_args: assert found_min == min_val assert found_max == max_val - @CrossSync.pytest + @pytest.mark.asyncio @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) async def test__manage_channel_refresh(self, num_cycles): # make sure that channels are properly refreshed - import threading + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, + ) + from google.api_core import grpc_helpers_async expected_grace = 9 expected_refresh = 0.5 channel_idx = 1 - grpc_lib = grpc.aio if CrossSync.is_async else grpc - new_channel = grpc_lib.insecure_channel("localhost:8080") + new_channel = grpc.aio.insecure_channel("localhost:8080") with mock.patch.object( - CrossSync.PooledTransport, "replace_channel" + PooledBigtableGrpcAsyncIOTransport, "replace_channel" ) as replace_channel: - sleep_tuple = ( - (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: + with mock.patch.object(asyncio, "sleep") as sleep: sleep.side_effect = [None for i in range(num_cycles)] + [ asyncio.CancelledError ] with mock.patch.object( - CrossSync.grpc_helpers, "create_channel" + 
grpc_helpers_async, "create_channel" ) as create_channel: create_channel.return_value = new_channel - with mock.patch.object( - self._get_target_class(), "_start_background_channel_refresh" - ): - client = self._make_client( - project="project-id", use_emulator=False - ) + client = self._make_one(project="project-id", use_emulator=False) create_channel.reset_mock() try: await client._manage_channel( @@ -649,7 +582,7 @@ async def test__manage_channel_refresh(self, num_cycles): assert kwargs["new_channel"] == new_channel await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test__register_instance(self): """ test instance registration @@ -667,7 +600,7 @@ async def test__register_instance(self): ) mock_channels = [mock.Mock() for i in range(5)] client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = CrossSync.Mock() + client_mock._ping_and_warm_instances = AsyncMock() table_mock = mock.Mock() await self._get_target_class()._register_instance( client_mock, "instance-1", table_mock @@ -720,7 +653,7 @@ async def test__register_instance(self): ] ) - @CrossSync.pytest + @pytest.mark.asyncio @pytest.mark.parametrize( "insert_instances,expected_active,expected_owner_keys", [ @@ -753,7 +686,7 @@ async def test__register_instance_state( ) mock_channels = [mock.Mock() for i in range(5)] client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = CrossSync.Mock() + client_mock._ping_and_warm_instances = AsyncMock() table_mock = mock.Mock() # register instances for instance, table, profile in insert_instances: @@ -779,9 +712,9 @@ async def test__register_instance_state( ] ) - @CrossSync.pytest + @pytest.mark.asyncio async def test__remove_instance_registration(self): - client = self._make_client(project="project-id") + client = self._make_one(project="project-id") table = mock.Mock() await client._register_instance("instance-1", table) await client._register_instance("instance-2", table) @@ -810,16 +743,16 @@ async def test__remove_instance_registration(self): assert len(client._active_instances) == 1 await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test__multiple_table_registration(self): """ registering with multiple tables with the same key should add multiple owners to instance_owners, but only keep one copy of shared key in active_instances """ - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + from google.cloud.bigtable.data._async.client import _WarmedInstanceKey - async with self._make_client(project="project-id") as client: + async with self._make_one(project="project-id") as client: async with client.get_table("instance_1", "table_1") as table_1: instance_1_path = client._gapic_client.instance_path( client.project, "instance_1" @@ -832,20 +765,12 @@ async def test__multiple_table_registration(self): assert id(table_1) in client._instance_owners[instance_1_key] # duplicate table should register in instance_owners under same key async with client.get_table("instance_1", "table_1") as table_2: - assert table_2._register_instance_future is not None - if not CrossSync.is_async: - # give the background task time to run - table_2._register_instance_future.result() assert len(client._instance_owners[instance_1_key]) == 2 assert len(client._active_instances) == 1 assert id(table_1) in client._instance_owners[instance_1_key] assert id(table_2) in client._instance_owners[instance_1_key] # unique table should register in instance_owners and active_instances async with 
client.get_table("instance_1", "table_3") as table_3: - assert table_3._register_instance_future is not None - if not CrossSync.is_async: - # give the background task time to run - table_3._register_instance_future.result() instance_3_path = client._gapic_client.instance_path( client.project, "instance_1" ) @@ -867,25 +792,17 @@ async def test__multiple_table_registration(self): assert instance_1_key not in client._active_instances assert len(client._instance_owners[instance_1_key]) == 0 - @CrossSync.pytest + @pytest.mark.asyncio async def test__multiple_instance_registration(self): """ registering with multiple instance keys should update the key in instance_owners and active_instances """ - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + from google.cloud.bigtable.data._async.client import _WarmedInstanceKey - async with self._make_client(project="project-id") as client: + async with self._make_one(project="project-id") as client: async with client.get_table("instance_1", "table_1") as table_1: - assert table_1._register_instance_future is not None - if not CrossSync.is_async: - # give the background task time to run - table_1._register_instance_future.result() async with client.get_table("instance_2", "table_2") as table_2: - assert table_2._register_instance_future is not None - if not CrossSync.is_async: - # give the background task time to run - table_2._register_instance_future.result() instance_1_path = client._gapic_client.instance_path( client.project, "instance_1" ) @@ -914,11 +831,12 @@ async def test__multiple_instance_registration(self): assert len(client._instance_owners[instance_1_key]) == 0 assert len(client._instance_owners[instance_2_key]) == 0 - @CrossSync.pytest + @pytest.mark.asyncio async def test_get_table(self): - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._async.client import _WarmedInstanceKey - client = self._make_client(project="project-id") + client = self._make_one(project="project-id") assert not client._active_instances expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -928,8 +846,8 @@ async def test_get_table(self): expected_table_id, expected_app_profile_id, ) - await CrossSync.yield_to_event_loop() - assert isinstance(table, CrossSync.TestTable._get_target_class()) + await asyncio.sleep(0) + assert isinstance(table, TableAsync) assert table.table_id == expected_table_id assert ( table.table_name @@ -949,14 +867,14 @@ async def test_get_table(self): assert client._instance_owners[instance_key] == {id(table)} await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test_get_table_arg_passthrough(self): """ All arguments passed in get_table should be sent to constructor """ - async with self._make_client(project="project-id") as client: - with mock.patch.object( - CrossSync.TestTable._get_target_class(), "__init__" + async with self._make_one(project="project-id") as client: + with mock.patch( + "google.cloud.bigtable.data._async.client.TableAsync.__init__", ) as mock_constructor: mock_constructor.return_value = None assert not client._active_instances @@ -982,26 +900,25 @@ async def test_get_table_arg_passthrough(self): **expected_kwargs, ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_get_table_context_manager(self): - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + from google.cloud.bigtable.data._async.client import TableAsync + from 
google.cloud.bigtable.data._async.client import _WarmedInstanceKey expected_table_id = "table-id" expected_instance_id = "instance-id" expected_app_profile_id = "app-profile-id" expected_project_id = "project-id" - with mock.patch.object( - CrossSync.TestTable._get_target_class(), "close" - ) as close_mock: - async with self._make_client(project=expected_project_id) as client: + with mock.patch.object(TableAsync, "close") as close_mock: + async with self._make_one(project=expected_project_id) as client: async with client.get_table( expected_instance_id, expected_table_id, expected_app_profile_id, ) as table: - await CrossSync.yield_to_event_loop() - assert isinstance(table, CrossSync.TestTable._get_target_class()) + await asyncio.sleep(0) + assert isinstance(table, TableAsync) assert table.table_id == expected_table_id assert ( table.table_name @@ -1021,16 +938,16 @@ async def test_get_table_context_manager(self): assert client._instance_owners[instance_key] == {id(table)} assert close_mock.call_count == 1 - @CrossSync.pytest + @pytest.mark.asyncio async def test_multiple_pool_sizes(self): # should be able to create multiple clients with different pool sizes without issue pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] for pool_size in pool_sizes: - client = self._make_client( + client = self._make_one( project="project-id", pool_size=pool_size, use_emulator=False ) assert len(client._channel_refresh_tasks) == pool_size - client_duplicate = self._make_client( + client_duplicate = self._make_one( project="project-id", pool_size=pool_size, use_emulator=False ) assert len(client_duplicate._channel_refresh_tasks) == pool_size @@ -1038,10 +955,14 @@ async def test_multiple_pool_sizes(self): await client.close() await client_duplicate.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test_close(self): + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, + ) + pool_size = 7 - client = self._make_client( + client = self._make_one( project="project-id", pool_size=pool_size, use_emulator=False ) assert len(client._channel_refresh_tasks) == pool_size @@ -1049,36 +970,36 @@ async def test_close(self): for task in client._channel_refresh_tasks: assert not task.done() with mock.patch.object( - CrossSync.PooledTransport, "close", CrossSync.Mock() + PooledBigtableGrpcAsyncIOTransport, "close", AsyncMock() ) as close_mock: await client.close() close_mock.assert_called_once() - if CrossSync.is_async: - close_mock.assert_awaited() + close_mock.assert_awaited() for task in tasks_list: assert task.done() + assert task.cancelled() + assert client._channel_refresh_tasks == [] - @CrossSync.pytest + @pytest.mark.asyncio async def test_close_with_timeout(self): pool_size = 7 expected_timeout = 19 - client = self._make_client(project="project-id", pool_size=pool_size) + client = self._make_one(project="project-id", pool_size=pool_size) tasks = list(client._channel_refresh_tasks) - with mock.patch.object(CrossSync, "wait", CrossSync.Mock()) as wait_for_mock: + with mock.patch.object(asyncio, "wait_for", AsyncMock()) as wait_for_mock: await client.close(timeout=expected_timeout) wait_for_mock.assert_called_once() - if CrossSync.is_async: - wait_for_mock.assert_awaited() + wait_for_mock.assert_awaited() assert wait_for_mock.call_args[1]["timeout"] == expected_timeout client._channel_refresh_tasks = tasks await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test_context_manager(self): # context manager should close 
the client cleanly - close_mock = CrossSync.Mock() + close_mock = AsyncMock() true_close = None - async with self._make_client(project="project-id") as client: + async with self._make_one(project="project-id") as client: true_close = client.close() client.close = close_mock for task in client._channel_refresh_tasks: @@ -1087,17 +1008,15 @@ async def test_context_manager(self): assert client._active_instances == set() close_mock.assert_not_called() close_mock.assert_called_once() - if CrossSync.is_async: - close_mock.assert_awaited() + close_mock.assert_awaited() # actually close the client await true_close - @CrossSync.drop_method def test_client_ctor_sync(self): # initializing client in a sync context should raise RuntimeError with pytest.warns(RuntimeWarning) as warnings: - client = self._make_client(project="project-id", use_emulator=False) + client = _make_client(project="project-id", use_emulator=False) expected_warning = [w for w in warnings if "client.py" in w.filename] assert len(expected_warning) == 1 assert ( @@ -1108,22 +1027,11 @@ def test_client_ctor_sync(self): assert client._channel_refresh_tasks == [] -@CrossSync.export_sync( - path="tests.unit.data._sync.test_client.TestTable", add_mapping_for_name="TestTable" -) class TestTableAsync: - @CrossSync.convert - def _make_client(self, *args, **kwargs): - return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) - - @staticmethod - @CrossSync.convert - def _get_target_class(): - return CrossSync.Table - - @CrossSync.pytest + @pytest.mark.asyncio async def test_table_ctor(self): - from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._async.client import _WarmedInstanceKey expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -1134,10 +1042,10 @@ async def test_table_ctor(self): expected_read_rows_attempt_timeout = 0.5 expected_mutate_rows_operation_timeout = 2.5 expected_mutate_rows_attempt_timeout = 0.75 - client = self._make_client() + client = _make_client() assert not client._active_instances - table = self._get_target_class()( + table = TableAsync( client, expected_instance_id, expected_table_id, @@ -1149,7 +1057,7 @@ async def test_table_ctor(self): default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, ) - await CrossSync.yield_to_event_loop() + await asyncio.sleep(0) assert table.table_id == expected_table_id assert table.instance_id == expected_instance_id assert table.app_profile_id == expected_app_profile_id @@ -1178,28 +1086,30 @@ async def test_table_ctor(self): == expected_mutate_rows_attempt_timeout ) # ensure task reaches completion - await table._register_instance_future - assert table._register_instance_future.done() - assert not table._register_instance_future.cancelled() - assert table._register_instance_future.exception() is None + await table._register_instance_task + assert table._register_instance_task.done() + assert not table._register_instance_task.cancelled() + assert table._register_instance_task.exception() is None await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test_table_ctor_defaults(self): """ should provide default timeout values and app_profile_id """ + from google.cloud.bigtable.data._async.client import TableAsync + expected_table_id = "table-id" expected_instance_id = "instance-id" - client = self._make_client() + client = 
_make_client() assert not client._active_instances - table = self._get_target_class()( + table = TableAsync( client, expected_instance_id, expected_table_id, ) - await CrossSync.yield_to_event_loop() + await asyncio.sleep(0) assert table.table_id == expected_table_id assert table.instance_id == expected_instance_id assert table.app_profile_id is None @@ -1212,12 +1122,14 @@ async def test_table_ctor_defaults(self): assert table.default_mutate_rows_attempt_timeout == 60 await client.close() - @CrossSync.pytest + @pytest.mark.asyncio async def test_table_ctor_invalid_timeout_values(self): """ bad timeout values should raise ValueError """ - client = self._make_client() + from google.cloud.bigtable.data._async.client import TableAsync + + client = _make_client() timeout_pairs = [ ("default_operation_timeout", "default_attempt_timeout"), @@ -1232,67 +1144,68 @@ async def test_table_ctor_invalid_timeout_values(self): ] for operation_timeout, attempt_timeout in timeout_pairs: with pytest.raises(ValueError) as e: - self._get_target_class()(client, "", "", **{attempt_timeout: -1}) + TableAsync(client, "", "", **{attempt_timeout: -1}) assert "attempt_timeout must be greater than 0" in str(e.value) with pytest.raises(ValueError) as e: - self._get_target_class()(client, "", "", **{operation_timeout: -1}) + TableAsync(client, "", "", **{operation_timeout: -1}) assert "operation_timeout must be greater than 0" in str(e.value) await client.close() - @CrossSync.drop_method def test_table_ctor_sync(self): # initializing client in a sync context should raise RuntimeError + from google.cloud.bigtable.data._async.client import TableAsync + client = mock.Mock() with pytest.raises(RuntimeError) as e: TableAsync(client, "instance-id", "table-id") assert e.match("TableAsync must be created within an async event loop context.") - @CrossSync.pytest + @pytest.mark.asyncio # iterate over all retryable rpcs @pytest.mark.parametrize( - "fn_name,fn_args,is_stream,extra_retryables", + "fn_name,fn_args,retry_fn_path,extra_retryables", [ ( "read_rows_stream", (ReadRowsQuery(),), - True, + "google.api_core.retry.retry_target_stream_async", (), ), ( "read_rows", (ReadRowsQuery(),), - True, + "google.api_core.retry.retry_target_stream_async", (), ), ( "read_row", (b"row_key",), - True, + "google.api_core.retry.retry_target_stream_async", (), ), ( "read_rows_sharded", ([ReadRowsQuery()],), - True, + "google.api_core.retry.retry_target_stream_async", (), ), ( "row_exists", (b"row_key",), - True, + "google.api_core.retry.retry_target_stream_async", (), ), - ("sample_row_keys", (), False, ()), + ("sample_row_keys", (), "google.api_core.retry.retry_target_async", ()), ( "mutate_row", (b"row_key", [mock.Mock()]), - False, + "google.api_core.retry.retry_target_async", (), ), ( "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), - False, + ([mutations.RowMutationEntry(b"key", [mock.Mock()])],), + "google.api_core.retry.retry_target_async", (_MutateRowsIncomplete,), ), ], @@ -1327,26 +1240,17 @@ async def test_customizable_retryable_errors( expected_retryables, fn_name, fn_args, - is_stream, + retry_fn_path, extra_retryables, ): """ Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. 
""" - retry_fn = "retry_target" - if is_stream: - retry_fn += "_stream" - if CrossSync.is_async: - retry_fn = f"CrossSync.{retry_fn}" - else: - retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" - with mock.patch( - f"google.cloud.bigtable.data._sync.cross_sync.{retry_fn}" - ) as retry_fn_mock: - async with self._make_client() as client: + with mock.patch(retry_fn_path) as retry_fn_mock: + async with _make_client() as client: table = client.get_table("instance-id", "table-id") - expected_predicate = expected_retryables.__contains__ + expected_predicate = lambda a: a in expected_retryables # noqa retry_fn_mock.side_effect = RuntimeError("stop early") with mock.patch( "google.api_core.retry.if_exception_type" @@ -1388,19 +1292,18 @@ async def test_customizable_retryable_errors( ], ) @pytest.mark.parametrize("include_app_profile", [True, False]) - @CrossSync.pytest - @CrossSync.convert + @pytest.mark.asyncio async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): """check that all requests attach proper metadata headers""" + from google.cloud.bigtable.data import TableAsync + profile = "profile" if include_app_profile else None - with mock.patch.object( - CrossSync.GapicClient, gapic_fn, CrossSync.Mock() + with mock.patch( + f"google.cloud.bigtable_v2.BigtableAsyncClient.{gapic_fn}", mock.AsyncMock() ) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") - async with self._make_client() as client: - table = self._get_target_class()( - client, "instance-id", "table-id", profile - ) + async with _make_client() as client: + table = TableAsync(client, "instance-id", "table-id", profile) try: test_fn = table.__getattribute__(fn_name) maybe_stream = await test_fn(*fn_args) @@ -1422,32 +1325,20 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ assert "app_profile_id=" not in goog_metadata -@CrossSync.export_sync( - path="tests.unit.data._sync.test_client.TestReadRows", - add_mapping_for_name="TestReadRows", -) -class TestReadRowsAsync: +class TestReadRows: """ Tests for table.read_rows and related methods. 
""" - @staticmethod - @CrossSync.convert - def _get_operation_class(): - return CrossSync._ReadRowsOperation - - @CrossSync.convert - def _make_client(self, *args, **kwargs): - return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) - - @CrossSync.convert def _make_table(self, *args, **kwargs): + from google.cloud.bigtable.data._async.client import TableAsync + client_mock = mock.Mock() client_mock._register_instance.side_effect = ( - lambda *args, **kwargs: CrossSync.yield_to_event_loop() + lambda *args, **kwargs: asyncio.sleep(0) ) client_mock._remove_instance_registration.side_effect = ( - lambda *args, **kwargs: CrossSync.yield_to_event_loop() + lambda *args, **kwargs: asyncio.sleep(0) ) kwargs["instance_id"] = kwargs.get( "instance_id", args[0] if args else "instance" @@ -1457,7 +1348,7 @@ def _make_table(self, *args, **kwargs): ) client_mock._gapic_client.table_path.return_value = kwargs["table_id"] client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] - return CrossSync.TestTable._get_target_class()(client_mock, *args, **kwargs) + return TableAsync(client_mock, *args, **kwargs) def _make_stats(self): from google.cloud.bigtable_v2.types import RequestStats @@ -1488,7 +1379,6 @@ def _make_chunk(*args, **kwargs): return ReadRowsResponse.CellChunk(*args, **kwargs) @staticmethod - @CrossSync.convert async def _make_gapic_stream( chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0, @@ -1504,34 +1394,27 @@ def __init__(self, chunk_list, sleep_time): def __aiter__(self): return self - def __iter__(self): - return self - async def __anext__(self): self.idx += 1 if len(self.chunk_list) > self.idx: if sleep_time: - await CrossSync.sleep(self.sleep_time) + await asyncio.sleep(self.sleep_time) chunk = self.chunk_list[self.idx] if isinstance(chunk, Exception): raise chunk else: return ReadRowsResponse(chunks=[chunk]) - raise CrossSync.StopIteration - - def __next__(self): - return self.__anext__() + raise StopAsyncIteration def cancel(self): pass return mock_stream(chunk_list, sleep_time) - @CrossSync.convert async def execute_fn(self, table, *args, **kwargs): return await table.read_rows(*args, **kwargs) - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows(self): query = ReadRowsQuery() chunks = [ @@ -1548,7 +1431,7 @@ async def test_read_rows(self): assert results[0].row_key == b"test_1" assert results[1].row_key == b"test_2" - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_stream(self): query = ReadRowsQuery() chunks = [ @@ -1567,7 +1450,7 @@ async def test_read_rows_stream(self): assert results[1].row_key == b"test_2" @pytest.mark.parametrize("include_app_profile", [True, False]) - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_query_matches_request(self, include_app_profile): from google.cloud.bigtable.data import RowRange from google.cloud.bigtable.data.row_filters import PassAllFilter @@ -1594,14 +1477,14 @@ async def test_read_rows_query_matches_request(self, include_app_profile): assert call_request == query_pb @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_timeout(self, operation_timeout): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows query = ReadRowsQuery() chunks = [self._make_chunk(row_key=b"test_1")] read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks, sleep_time=0.15 + chunks, sleep_time=1 ) try: await 
table.read_rows(query, operation_timeout=operation_timeout) @@ -1619,7 +1502,7 @@ async def test_read_rows_timeout(self, operation_timeout): (0.05, 0.24, 5), ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_attempt_timeout( self, per_request_t, operation_t, expected_num ): @@ -1682,7 +1565,7 @@ async def test_read_rows_attempt_timeout( core_exceptions.ServiceUnavailable, ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_retryable_error(self, exc_type): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows @@ -1713,7 +1596,7 @@ async def test_read_rows_retryable_error(self, exc_type): InvalidChunk, ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_non_retryable_error(self, exc_type): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows @@ -1727,17 +1610,18 @@ async def test_read_rows_non_retryable_error(self, exc_type): except exc_type as e: assert e == expected_error - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_revise_request(self): """ Ensure that _revise_request is called between retries """ + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable_v2.types import RowSet return_val = RowSet() with mock.patch.object( - self._get_operation_class(), "_revise_request_rowset" + _ReadRowsOperationAsync, "_revise_request_rowset" ) as revise_rowset: revise_rowset.return_value = return_val async with self._make_table() as table: @@ -1761,14 +1645,16 @@ async def test_read_rows_revise_request(self): revised_call = read_rows.call_args_list[1].args[0] assert revised_call.rows == return_val - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_default_timeouts(self): """ Ensure that the default timeouts are set on the read rows operation when not overridden """ + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + operation_timeout = 8 attempt_timeout = 4 - with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: + with mock.patch.object(_ReadRowsOperationAsync, "__init__") as mock_op: mock_op.side_effect = RuntimeError("mock error") async with self._make_table( default_read_rows_operation_timeout=operation_timeout, @@ -1782,14 +1668,16 @@ async def test_read_rows_default_timeouts(self): assert kwargs["operation_timeout"] == operation_timeout assert kwargs["attempt_timeout"] == attempt_timeout - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_default_timeout_override(self): """ When timeouts are passed, they overwrite default values """ + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + operation_timeout = 8 attempt_timeout = 4 - with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: + with mock.patch.object(_ReadRowsOperationAsync, "__init__") as mock_op: mock_op.side_effect = RuntimeError("mock error") async with self._make_table( default_operation_timeout=99, default_attempt_timeout=97 @@ -1806,10 +1694,10 @@ async def test_read_rows_default_timeout_override(self): assert kwargs["operation_timeout"] == operation_timeout assert kwargs["attempt_timeout"] == attempt_timeout - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_row(self): """Test reading a single row""" - async with self._make_client() as client: + async with _make_client() as client: table = 
client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1834,10 +1722,10 @@ async def test_read_row(self): assert query.row_ranges == [] assert query.limit == 1 - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_row_w_filter(self): """Test reading a single row with an added filter""" - async with self._make_client() as client: + async with _make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1867,10 +1755,10 @@ async def test_read_row_w_filter(self): assert query.limit == 1 assert query.filter == expected_filter - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_row_no_response(self): """should return None if row does not exist""" - async with self._make_client() as client: + async with _make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1902,10 +1790,10 @@ async def test_read_row_no_response(self): ([object(), object()], True), ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_row_exists(self, return_value, expected_result): """Test checking for row existence""" - async with self._make_client() as client: + async with _make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1939,35 +1827,32 @@ async def test_row_exists(self, return_value, expected_result): assert query.filter._to_dict() == expected_filter -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRowsSharded") -class TestReadRowsShardedAsync: - @CrossSync.convert - def _make_client(self, *args, **kwargs): - return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) - - @CrossSync.pytest +class TestReadRowsSharded: + @pytest.mark.asyncio async def test_read_rows_sharded_empty_query(self): - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as exc: await table.read_rows_sharded([]) assert "empty sharded_query" in str(exc.value) - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_sharded_multiple_queries(self): """ Test with multiple queries. 
Should return results from both """ - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( table.client._gapic_client, "read_rows" ) as read_rows: - read_rows.side_effect = lambda *args, **kwargs: CrossSync.TestReadRows._make_gapic_stream( - [ - CrossSync.TestReadRows._make_chunk(row_key=k) - for k in args[0].rows.row_keys - ] + read_rows.side_effect = ( + lambda *args, **kwargs: TestReadRows._make_gapic_stream( + [ + TestReadRows._make_chunk(row_key=k) + for k in args[0].rows.row_keys + ] + ) ) query_1 = ReadRowsQuery(b"test_1") query_2 = ReadRowsQuery(b"test_2") @@ -1977,19 +1862,19 @@ async def test_read_rows_sharded_multiple_queries(self): assert result[1].row_key == b"test_2" @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24]) - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_sharded_multiple_queries_calls(self, n_queries): """ Each query should trigger a separate read_rows call """ - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: query_list = [ReadRowsQuery() for _ in range(n_queries)] await table.read_rows_sharded(query_list) assert read_rows.call_count == n_queries - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_sharded_errors(self): """ Errors should be exposed as ShardedReadRowsExceptionGroups @@ -1997,7 +1882,7 @@ async def test_read_rows_sharded_errors(self): from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.cloud.bigtable.data.exceptions import FailedQueryShardError - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = RuntimeError("mock error") @@ -2017,7 +1902,7 @@ async def test_read_rows_sharded_errors(self): assert exc.value.exceptions[1].index == 1 assert exc.value.exceptions[1].query == query_2 - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_sharded_concurrent(self): """ Ensure sharded requests are concurrent @@ -2028,7 +1913,7 @@ async def mock_call(*args, **kwargs): await asyncio.sleep(0.1) return [mock.Mock()] - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call @@ -2041,14 +1926,14 @@ async def mock_call(*args, **kwargs): # if run in sequence, we would expect this to take 1 second assert call_time < 0.2 - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_sharded_concurrency_limit(self): """ Only 10 queries should be processed concurrently. 
Others should be queued Should start a new query as soon as previous finishes """ - from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT + from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT assert _CONCURRENCY_LIMIT == 10 # change this test if this changes num_queries = 15 @@ -2066,7 +1951,7 @@ async def mock_call(*args, **kwargs): starting_timeout = 10 - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call @@ -2090,13 +1975,13 @@ async def mock_call(*args, **kwargs): idx = i + _CONCURRENCY_LIMIT assert rpc_start_list[idx] - (i * increment_time) < eps - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_sharded_expirary(self): """ If the operation times out before all shards complete, should raise a ShardedReadRowsExceptionGroup """ - from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT + from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.api_core.exceptions import DeadlineExceeded @@ -2116,7 +2001,7 @@ async def mock_call(*args, **kwargs): await asyncio.sleep(next_item) return [mock.Mock()] - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call @@ -2130,7 +2015,7 @@ async def mock_call(*args, **kwargs): # should keep successful queries assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_rows_sharded_negative_batch_timeout(self): """ try to run with batch that starts after operation timeout @@ -2141,10 +2026,10 @@ async def test_read_rows_sharded_negative_batch_timeout(self): from google.api_core.exceptions import DeadlineExceeded async def mock_call(*args, **kwargs): - await CrossSync.sleep(0.05) + await asyncio.sleep(0.05) return [mock.Mock()] - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call @@ -2159,20 +2044,14 @@ async def mock_call(*args, **kwargs): ) -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestSampleRowKeys") -class TestSampleRowKeysAsync: - @CrossSync.convert - def _make_client(self, *args, **kwargs): - return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) - - @CrossSync.convert +class TestSampleRowKeys: async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): from google.cloud.bigtable_v2.types import SampleRowKeysResponse for value in sample_list: yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) - @CrossSync.pytest + @pytest.mark.asyncio async def test_sample_row_keys(self): """ Test that method returns the expected key samples @@ -2182,10 +2061,10 @@ async def test_sample_row_keys(self): (b"test_2", 100), (b"test_3", 200), ] - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", CrossSync.Mock() + table.client._gapic_client, "sample_row_keys", AsyncMock() ) as 
sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream(samples) result = await table.sample_row_keys() @@ -2197,12 +2076,12 @@ async def test_sample_row_keys(self): assert result[1] == samples[1] assert result[2] == samples[2] - @CrossSync.pytest + @pytest.mark.asyncio async def test_sample_row_keys_bad_timeout(self): """ should raise error if timeout is negative """ - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.sample_row_keys(operation_timeout=-1) @@ -2211,11 +2090,11 @@ async def test_sample_row_keys_bad_timeout(self): await table.sample_row_keys(attempt_timeout=-1) assert "attempt_timeout must be greater than 0" in str(e.value) - @CrossSync.pytest + @pytest.mark.asyncio async def test_sample_row_keys_default_timeout(self): """Should fallback to using table default operation_timeout""" expected_timeout = 99 - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table( "i", "t", @@ -2223,7 +2102,7 @@ async def test_sample_row_keys_default_timeout(self): default_attempt_timeout=expected_timeout, ) as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", CrossSync.Mock() + table.client._gapic_client, "sample_row_keys", AsyncMock() ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) result = await table.sample_row_keys() @@ -2232,7 +2111,7 @@ async def test_sample_row_keys_default_timeout(self): assert result == [] assert kwargs["retry"] is None - @CrossSync.pytest + @pytest.mark.asyncio async def test_sample_row_keys_gapic_params(self): """ make sure arguments are propagated to gapic call as expected @@ -2241,12 +2120,12 @@ async def test_sample_row_keys_gapic_params(self): expected_profile = "test1" instance = "instance_name" table_id = "my_table" - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table( instance, table_id, app_profile_id=expected_profile ) as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", CrossSync.Mock() + table.client._gapic_client, "sample_row_keys", AsyncMock() ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) await table.sample_row_keys(attempt_timeout=expected_timeout) @@ -2266,7 +2145,7 @@ async def test_sample_row_keys_gapic_params(self): core_exceptions.ServiceUnavailable, ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_sample_row_keys_retryable_errors(self, retryable_exception): """ retryable errors should be retried until timeout @@ -2274,10 +2153,10 @@ async def test_sample_row_keys_retryable_errors(self, retryable_exception): from google.api_core.exceptions import DeadlineExceeded from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", CrossSync.Mock() + table.client._gapic_client, "sample_row_keys", AsyncMock() ) as sample_row_keys: sample_row_keys.side_effect = retryable_exception("mock") with pytest.raises(DeadlineExceeded) as e: @@ -2298,30 +2177,23 @@ async def test_sample_row_keys_retryable_errors(self, retryable_exception): core_exceptions.Aborted, ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def 
test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): """ non-retryable errors should cause a raise """ - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", CrossSync.Mock() + table.client._gapic_client, "sample_row_keys", AsyncMock() ) as sample_row_keys: sample_row_keys.side_effect = non_retryable_exception("mock") with pytest.raises(non_retryable_exception): await table.sample_row_keys() -@CrossSync.export_sync( - path="tests.unit.data._sync.test_client.TestMutateRow", -) -class TestMutateRowAsync: - @CrossSync.convert - def _make_client(self, *args, **kwargs): - return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) - - @CrossSync.pytest +class TestMutateRow: + @pytest.mark.asyncio @pytest.mark.parametrize( "mutation_arg", [ @@ -2342,7 +2214,7 @@ def _make_client(self, *args, **kwargs): async def test_mutate_row(self, mutation_arg): """Test mutations with no errors""" expected_attempt_timeout = 19 - async with self._make_client(project="project") as client: + async with _make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2377,12 +2249,12 @@ async def test_mutate_row(self, mutation_arg): core_exceptions.ServiceUnavailable, ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_mutate_row_retryable_errors(self, retryable_exception): from google.api_core.exceptions import DeadlineExceeded from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - async with self._make_client(project="project") as client: + async with _make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2405,14 +2277,14 @@ async def test_mutate_row_retryable_errors(self, retryable_exception): core_exceptions.ServiceUnavailable, ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_mutate_row_non_idempotent_retryable_errors( self, retryable_exception ): """ Non-idempotent mutations should not be retried """ - async with self._make_client(project="project") as client: + async with _make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2438,9 +2310,9 @@ async def test_mutate_row_non_idempotent_retryable_errors( core_exceptions.Aborted, ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_mutate_row_non_retryable_errors(self, non_retryable_exception): - async with self._make_client(project="project") as client: + async with _make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2459,14 +2331,14 @@ async def test_mutate_row_non_retryable_errors(self, non_retryable_exception): ) @pytest.mark.parametrize("include_app_profile", [True, False]) - @CrossSync.pytest + @pytest.mark.asyncio async def test_mutate_row_metadata(self, include_app_profile): """request should attach metadata headers""" profile = "profile" if include_app_profile else None - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("i", "t", app_profile_id=profile) as table: with mock.patch.object( - client._gapic_client, "mutate_row", 
CrossSync.Mock() + client._gapic_client, "mutate_row", AsyncMock() ) as read_rows: await table.mutate_row("rk", mock.Mock()) kwargs = read_rows.call_args_list[0].kwargs @@ -2483,24 +2355,16 @@ async def test_mutate_row_metadata(self, include_app_profile): assert "app_profile_id=" not in goog_metadata @pytest.mark.parametrize("mutations", [[], None]) - @CrossSync.pytest + @pytest.mark.asyncio async def test_mutate_row_no_mutations(self, mutations): - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.mutate_row("key", mutations=mutations) assert e.value.args[0] == "No mutations provided" -@CrossSync.export_sync( - path="tests.unit.data._sync.test_client.TestBulkMutateRows", -) -class TestBulkMutateRowsAsync: - @CrossSync.convert - def _make_client(self, *args, **kwargs): - return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) - - @CrossSync.convert +class TestBulkMutateRows: async def _mock_response(self, response_list): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -2525,8 +2389,8 @@ async def generator(): return generator() - @CrossSync.pytest - @CrossSync.pytest + @pytest.mark.asyncio + @pytest.mark.asyncio @pytest.mark.parametrize( "mutation_arg", [ @@ -2549,7 +2413,7 @@ async def generator(): async def test_bulk_mutate_rows(self, mutation_arg): """Test mutations with no errors""" expected_attempt_timeout = 19 - async with self._make_client(project="project") as client: + async with _make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2570,10 +2434,10 @@ async def test_bulk_mutate_rows(self, mutation_arg): assert kwargs["timeout"] == expected_attempt_timeout assert kwargs["retry"] is None - @CrossSync.pytest + @pytest.mark.asyncio async def test_bulk_mutate_rows_multiple_entries(self): """Test mutations with no errors""" - async with self._make_client(project="project") as client: + async with _make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2594,7 +2458,7 @@ async def test_bulk_mutate_rows_multiple_entries(self): assert kwargs["entries"][0] == entry_1._to_pb() assert kwargs["entries"][1] == entry_2._to_pb() - @CrossSync.pytest + @pytest.mark.asyncio @pytest.mark.parametrize( "exception", [ @@ -2614,7 +2478,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_retryable( MutationsExceptionGroup, ) - async with self._make_client(project="project") as client: + async with _make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2639,7 +2503,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_retryable( cause.exceptions[-1], core_exceptions.DeadlineExceeded ) - @CrossSync.pytest + @pytest.mark.asyncio @pytest.mark.parametrize( "exception", [ @@ -2660,7 +2524,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable( MutationsExceptionGroup, ) - async with self._make_client(project="project") as client: + async with _make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2687,7 +2551,7 @@ async def 
test_bulk_mutate_rows_idempotent_mutation_error_non_retryable( core_exceptions.ServiceUnavailable, ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_bulk_mutate_idempotent_retryable_request_errors( self, retryable_exception ): @@ -2700,7 +2564,7 @@ async def test_bulk_mutate_idempotent_retryable_request_errors( MutationsExceptionGroup, ) - async with self._make_client(project="project") as client: + async with _make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2721,7 +2585,7 @@ async def test_bulk_mutate_idempotent_retryable_request_errors( assert isinstance(cause, RetryExceptionGroup) assert isinstance(cause.exceptions[0], retryable_exception) - @CrossSync.pytest + @pytest.mark.asyncio @pytest.mark.parametrize( "retryable_exception", [ @@ -2738,7 +2602,7 @@ async def test_bulk_mutate_rows_non_idempotent_retryable_errors( MutationsExceptionGroup, ) - async with self._make_client(project="project") as client: + async with _make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2770,7 +2634,7 @@ async def test_bulk_mutate_rows_non_idempotent_retryable_errors( ValueError, ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception): """ If the request fails with a non-retryable error, mutations should not be retried @@ -2780,7 +2644,7 @@ async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_excepti MutationsExceptionGroup, ) - async with self._make_client(project="project") as client: + async with _make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2800,7 +2664,7 @@ async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_excepti cause = failed_exception.__cause__ assert isinstance(cause, non_retryable_exception) - @CrossSync.pytest + @pytest.mark.asyncio async def test_bulk_mutate_error_index(self): """ Test partial failure, partial success. 
Errors should be associated with the correct index @@ -2816,7 +2680,7 @@ async def test_bulk_mutate_error_index(self): MutationsExceptionGroup, ) - async with self._make_client(project="project") as client: + async with _make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2851,14 +2715,14 @@ async def test_bulk_mutate_error_index(self): assert isinstance(cause.exceptions[1], DeadlineExceeded) assert isinstance(cause.exceptions[2], FailedPrecondition) - @CrossSync.pytest + @pytest.mark.asyncio async def test_bulk_mutate_error_recovery(self): """ If an error occurs, then resolves, no exception should be raised """ from google.api_core.exceptions import DeadlineExceeded - async with self._make_client(project="project") as client: + async with _make_client(project="project") as client: table = client.get_table("instance", "table") with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: # fail with a retryable error, then a non-retryable one @@ -2876,19 +2740,14 @@ async def test_bulk_mutate_error_recovery(self): await table.bulk_mutate_rows(entries, operation_timeout=1000) -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestCheckAndMutateRow") -class TestCheckAndMutateRowAsync: - @CrossSync.convert - def _make_client(self, *args, **kwargs): - return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) - +class TestCheckAndMutateRow: @pytest.mark.parametrize("gapic_result", [True, False]) - @CrossSync.pytest + @pytest.mark.asyncio async def test_check_and_mutate(self, gapic_result): from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse app_profile = "app_profile_id" - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table( "instance", "table", app_profile_id=app_profile ) as table: @@ -2925,10 +2784,10 @@ async def test_check_and_mutate(self, gapic_result): assert kwargs["timeout"] == operation_timeout assert kwargs["retry"] is None - @CrossSync.pytest + @pytest.mark.asyncio async def test_check_and_mutate_bad_timeout(self): """Should raise error if operation_timeout < 0""" - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.check_and_mutate_row( @@ -2940,13 +2799,13 @@ async def test_check_and_mutate_bad_timeout(self): ) assert str(e.value) == "operation_timeout must be greater than 0" - @CrossSync.pytest + @pytest.mark.asyncio async def test_check_and_mutate_single_mutations(self): """if single mutations are passed, they should be internally wrapped in a list""" from google.cloud.bigtable.data.mutations import SetCell from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "check_and_mutate_row" @@ -2966,7 +2825,7 @@ async def test_check_and_mutate_single_mutations(self): assert kwargs["true_mutations"] == [true_mutation._to_pb()] assert kwargs["false_mutations"] == [false_mutation._to_pb()] - @CrossSync.pytest + @pytest.mark.asyncio async def test_check_and_mutate_predicate_object(self): """predicate filter should be passed to gapic request""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse @@ -2974,7 +2833,7 @@ 
async def test_check_and_mutate_predicate_object(self): mock_predicate = mock.Mock() predicate_pb = {"predicate": "dict"} mock_predicate._to_pb.return_value = predicate_pb - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "check_and_mutate_row" @@ -2992,7 +2851,7 @@ async def test_check_and_mutate_predicate_object(self): assert mock_predicate._to_pb.call_count == 1 assert kwargs["retry"] is None - @CrossSync.pytest + @pytest.mark.asyncio async def test_check_and_mutate_mutations_parsing(self): """mutations objects should be converted to protos""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse @@ -3002,7 +2861,7 @@ async def test_check_and_mutate_mutations_parsing(self): for idx, mutation in enumerate(mutations): mutation._to_pb.return_value = f"fake {idx}" mutations.append(DeleteAllFromRow()) - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "check_and_mutate_row" @@ -3029,12 +2888,7 @@ async def test_check_and_mutate_mutations_parsing(self): ) -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadModifyWriteRow") -class TestReadModifyWriteRowAsync: - @CrossSync.convert - def _make_client(self, *args, **kwargs): - return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) - +class TestReadModifyWriteRow: @pytest.mark.parametrize( "call_rules,expected_rules", [ @@ -3056,12 +2910,12 @@ def _make_client(self, *args, **kwargs): ), ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): """ Test that the gapic call is called with given rules """ - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" @@ -3073,21 +2927,21 @@ async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules assert found_kwargs["retry"] is None @pytest.mark.parametrize("rules", [[], None]) - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_modify_write_no_rules(self, rules): - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.read_modify_write_row("key", rules=rules) assert e.value.args[0] == "rules must contain at least one item" - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_modify_write_call_defaults(self): instance = "instance1" table_id = "table1" project = "project1" row_key = "row_key1" - async with self._make_client(project=project) as client: + async with _make_client(project=project) as client: async with client.get_table(instance, table_id) as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" @@ -3103,12 +2957,12 @@ async def test_read_modify_write_call_defaults(self): assert kwargs["row_key"] == row_key.encode() assert kwargs["timeout"] > 1 - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_modify_write_call_overrides(self): row_key = b"row_key1" expected_timeout = 12345 profile_id = "profile1" - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table( "instance", "table_id", 
app_profile_id=profile_id ) as table: @@ -3126,10 +2980,10 @@ async def test_read_modify_write_call_overrides(self): assert kwargs["row_key"] == row_key assert kwargs["timeout"] == expected_timeout - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_modify_write_string_key(self): row_key = "string_row_key1" - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table_id") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" @@ -3139,7 +2993,7 @@ async def test_read_modify_write_string_key(self): kwargs = mock_gapic.call_args_list[0][1] assert kwargs["row_key"] == row_key.encode() - @CrossSync.pytest + @pytest.mark.asyncio async def test_read_modify_write_row_building(self): """ results from gapic call should be used to construct row @@ -3149,7 +3003,7 @@ async def test_read_modify_write_row_building(self): from google.cloud.bigtable_v2.types import Row as RowPB mock_response = ReadModifyWriteRowResponse(row=RowPB()) - async with self._make_client() as client: + async with _make_client() as client: async with client.get_table("instance", "table_id") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index fcd425273..cca7c9824 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -14,39 +14,33 @@ import pytest import asyncio -import time import google.api_core.exceptions as core_exceptions -import google.api_core.retry from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.cloud.bigtable.data import TABLE_DEFAULT -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - # try/except added for compatibility with python < 3.8 try: from unittest import mock + from unittest.mock import AsyncMock except ImportError: # pragma: NO COVER import mock # type: ignore + from mock import AsyncMock # type: ignore -@CrossSync.export_sync( - path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl" -) -class Test_FlowControl: - @staticmethod - @CrossSync.convert - def _target_class(): - return CrossSync._FlowControl +def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + +class Test_FlowControl: def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): - return self._target_class()(max_mutation_count, max_mutation_bytes) + from google.cloud.bigtable.data._async.mutations_batcher import ( + _FlowControlAsync, + ) - @staticmethod - def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation + return _FlowControlAsync(max_mutation_count, max_mutation_bytes) def test_ctor(self): max_mutation_count = 9 @@ -56,7 +50,7 @@ def test_ctor(self): assert instance._max_mutation_bytes == max_mutation_bytes assert instance._in_flight_mutation_count == 0 assert instance._in_flight_mutation_bytes == 0 - assert isinstance(instance._capacity_condition, CrossSync.Condition) + assert isinstance(instance._capacity_condition, asyncio.Condition) def test_ctor_invalid_values(self): """Test that values are positive, and fit within expected limits""" @@ -116,7 +110,7 @@ def test__has_capacity( instance._in_flight_mutation_bytes = existing_size assert 
instance._has_capacity(new_count, new_size) == expected - @CrossSync.pytest + @pytest.mark.asyncio @pytest.mark.parametrize( "existing_count,existing_size,added_count,added_size,new_count,new_size", [ @@ -144,12 +138,12 @@ async def test_remove_from_flow_value_update( instance = self._make_one() instance._in_flight_mutation_count = existing_count instance._in_flight_mutation_bytes = existing_size - mutation = self._make_mutation(added_count, added_size) + mutation = _make_mutation(added_count, added_size) await instance.remove_from_flow(mutation) assert instance._in_flight_mutation_count == new_count assert instance._in_flight_mutation_bytes == new_size - @CrossSync.pytest + @pytest.mark.asyncio async def test__remove_from_flow_unlock(self): """capacity condition should notify after mutation is complete""" instance = self._make_one(10, 10) @@ -162,50 +156,36 @@ async def task_routine(): lambda: instance._has_capacity(1, 1) ) - if CrossSync.is_async: - # for async class, build task to test flow unlock - task = asyncio.create_task(task_routine()) - - def task_alive(): - return not task.done() - - else: - # this branch will be tested in sync version of this test - import threading - - thread = threading.Thread(target=task_routine) - thread.start() - task_alive = thread.is_alive - await CrossSync.sleep(0.05) + task = asyncio.create_task(task_routine()) + await asyncio.sleep(0.05) # should be blocked due to capacity - assert task_alive() is True + assert task.done() is False # try changing size - mutation = self._make_mutation(count=0, size=5) - + mutation = _make_mutation(count=0, size=5) await instance.remove_from_flow([mutation]) - await CrossSync.sleep(0.05) + await asyncio.sleep(0.05) assert instance._in_flight_mutation_count == 10 assert instance._in_flight_mutation_bytes == 5 - assert task_alive() is True + assert task.done() is False # try changing count instance._in_flight_mutation_bytes = 10 - mutation = self._make_mutation(count=5, size=0) + mutation = _make_mutation(count=5, size=0) await instance.remove_from_flow([mutation]) - await CrossSync.sleep(0.05) + await asyncio.sleep(0.05) assert instance._in_flight_mutation_count == 5 assert instance._in_flight_mutation_bytes == 10 - assert task_alive() is True + assert task.done() is False # try changing both instance._in_flight_mutation_count = 10 - mutation = self._make_mutation(count=5, size=5) + mutation = _make_mutation(count=5, size=5) await instance.remove_from_flow([mutation]) - await CrossSync.sleep(0.05) + await asyncio.sleep(0.05) assert instance._in_flight_mutation_count == 5 assert instance._in_flight_mutation_bytes == 5 # task should be complete - assert task_alive() is False + assert task.done() is True - @CrossSync.pytest + @pytest.mark.asyncio @pytest.mark.parametrize( "mutations,count_cap,size_cap,expected_results", [ @@ -230,7 +210,7 @@ async def test_add_to_flow(self, mutations, count_cap, size_cap, expected_result """ Test batching with various flow control settings """ - mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] + mutation_objs = [_make_mutation(count=m[0], size=m[1]) for m in mutations] instance = self._make_one(count_cap, size_cap) i = 0 async for batch in instance.add_to_flow(mutation_objs): @@ -246,7 +226,7 @@ async def test_add_to_flow(self, mutations, count_cap, size_cap, expected_result i += 1 assert i == len(expected_results) - @CrossSync.pytest + @pytest.mark.asyncio @pytest.mark.parametrize( "mutations,max_limit,expected_results", [ @@ -262,12 +242,11 @@ async def 
test_add_to_flow_max_mutation_limits( Test flow control running up against the max API limit Should submit request early, even if the flow control has room for more """ - subpath = "_async" if CrossSync.is_async else "_sync" - path = f"google.cloud.bigtable.data.{subpath}.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT" - with mock.patch(path, max_limit): - mutation_objs = [ - self._make_mutation(count=m[0], size=m[1]) for m in mutations - ] + with mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", + max_limit, + ): + mutation_objs = [_make_mutation(count=m[0], size=m[1]) for m in mutations] # flow control has no limits except API restrictions instance = self._make_one(float("inf"), float("inf")) i = 0 @@ -284,14 +263,14 @@ async def test_add_to_flow_max_mutation_limits( i += 1 assert i == len(expected_results) - @CrossSync.pytest + @pytest.mark.asyncio async def test_add_to_flow_oversize(self): """ mutations over the flow control limits should still be accepted """ instance = self._make_one(2, 3) - large_size_mutation = self._make_mutation(count=1, size=10) - large_count_mutation = self._make_mutation(count=10, size=1) + large_size_mutation = _make_mutation(count=1, size=10) + large_count_mutation = _make_mutation(count=10, size=1) results = [out async for out in instance.add_to_flow([large_size_mutation])] assert len(results) == 1 await instance.remove_from_flow(results[0]) @@ -301,13 +280,13 @@ async def test_add_to_flow_oversize(self): assert len(count_results) == 1 -@CrossSync.export_sync( - path="tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" -) class TestMutationsBatcherAsync: - @CrossSync.convert def _get_target_class(self): - return CrossSync.MutationsBatcher + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) + + return MutationsBatcherAsync def _make_one(self, table=None, **kwargs): from google.api_core.exceptions import DeadlineExceeded @@ -324,140 +303,132 @@ def _make_one(self, table=None, **kwargs): return self._get_target_class()(table, **kwargs) - @staticmethod - def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation - - @CrossSync.pytest - async def test_ctor_defaults(self): - with mock.patch.object( - self._get_target_class(), "_timer_routine", return_value=CrossSync.Future() - ) as flush_timer_mock: - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = [Exception] - async with self._make_one(table) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._max_mutation_count == 100000 - assert instance._flow_control._max_mutation_bytes == 104857600 - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - assert ( - instance._operation_timeout - == table.default_mutate_rows_operation_timeout - ) - assert ( - instance._attempt_timeout - == table.default_mutate_rows_attempt_timeout - 
) - assert ( - instance._retryable_errors - == table.default_mutate_rows_retryable_errors - ) - await CrossSync.yield_to_event_loop() - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] == 5 - assert isinstance(instance._flush_timer, CrossSync.Future) + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer" + ) + @pytest.mark.asyncio + async def test_ctor_defaults(self, flush_timer_mock): + flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = [Exception] + async with self._make_one(table) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._max_mutation_count == 100000 + assert instance._flow_control._max_mutation_bytes == 104857600 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert ( + instance._operation_timeout + == table.default_mutate_rows_operation_timeout + ) + assert ( + instance._attempt_timeout == table.default_mutate_rows_attempt_timeout + ) + assert ( + instance._retryable_errors == table.default_mutate_rows_retryable_errors + ) + await asyncio.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == 5 + assert isinstance(instance._flush_timer, asyncio.Future) - @CrossSync.pytest - async def test_ctor_explicit(self): + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer", + ) + @pytest.mark.asyncio + async def test_ctor_explicit(self, flush_timer_mock): """Test with explicit parameters""" - with mock.patch.object( - self._get_target_class(), "_timer_routine", return_value=CrossSync.Future() - ) as flush_timer_mock: - table = mock.Mock() - flush_interval = 20 - flush_limit_count = 17 - flush_limit_bytes = 19 - flow_control_max_mutation_count = 1001 - flow_control_max_bytes = 12 - operation_timeout = 11 - attempt_timeout = 2 - retryable_errors = [Exception] - async with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - flow_control_max_mutation_count=flow_control_max_mutation_count, - flow_control_max_bytes=flow_control_max_bytes, - batch_operation_timeout=operation_timeout, - batch_attempt_timeout=attempt_timeout, - batch_retryable_errors=retryable_errors, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert ( - instance._flow_control._max_mutation_count - == flow_control_max_mutation_count - ) - assert ( - instance._flow_control._max_mutation_bytes == flow_control_max_bytes - ) - assert instance._flow_control._in_flight_mutation_count == 0 - 
assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - assert instance._operation_timeout == operation_timeout - assert instance._attempt_timeout == attempt_timeout - assert instance._retryable_errors == retryable_errors - await CrossSync.yield_to_event_loop() - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] == flush_interval - assert isinstance(instance._flush_timer, CrossSync.Future) - - @CrossSync.pytest - async def test_ctor_no_flush_limits(self): + flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) + table = mock.Mock() + flush_interval = 20 + flush_limit_count = 17 + flush_limit_bytes = 19 + flow_control_max_mutation_count = 1001 + flow_control_max_bytes = 12 + operation_timeout = 11 + attempt_timeout = 2 + retryable_errors = [Exception] + async with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=operation_timeout, + batch_attempt_timeout=attempt_timeout, + batch_retryable_errors=retryable_errors, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert ( + instance._flow_control._max_mutation_count + == flow_control_max_mutation_count + ) + assert instance._flow_control._max_mutation_bytes == flow_control_max_bytes + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert instance._operation_timeout == operation_timeout + assert instance._attempt_timeout == attempt_timeout + assert instance._retryable_errors == retryable_errors + await asyncio.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == flush_interval + assert isinstance(instance._flush_timer, asyncio.Future) + + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer" + ) + @pytest.mark.asyncio + async def test_ctor_no_flush_limits(self, flush_timer_mock): """Test with None for flush limits""" - with mock.patch.object( - self._get_target_class(), "_timer_routine", return_value=CrossSync.Future() - ) as flush_timer_mock: - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = () - flush_interval = None - flush_limit_count = None - flush_limit_bytes = None - async with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._staged_entries == [] - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._in_flight_mutation_count == 0 - assert 
instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - await CrossSync.yield_to_event_loop() - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] is None - assert isinstance(instance._flush_timer, CrossSync.Future) + flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = () + flush_interval = None + flush_limit_count = None + flush_limit_bytes = None + async with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._staged_entries == [] + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + await asyncio.sleep(0) + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] is None + assert isinstance(instance._flush_timer, asyncio.Future) - @CrossSync.pytest + @pytest.mark.asyncio async def test_ctor_invalid_values(self): """Test that timeout values are positive, and fit within expected limits""" with pytest.raises(ValueError) as e: @@ -467,21 +438,24 @@ async def test_ctor_invalid_values(self): self._make_one(batch_attempt_timeout=-1) assert "attempt_timeout must be greater than 0" in str(e.value) - @CrossSync.convert def test_default_argument_consistency(self): """ We supply default arguments in MutationsBatcherAsync.__init__, and in table.mutations_batcher. 
Make sure any changes to defaults are applied to both places """ + from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) import inspect get_batcher_signature = dict( - inspect.signature(CrossSync.Table.mutations_batcher).parameters + inspect.signature(TableAsync.mutations_batcher).parameters ) get_batcher_signature.pop("self") batcher_init_signature = dict( - inspect.signature(self._get_target_class()).parameters + inspect.signature(MutationsBatcherAsync).parameters ) batcher_init_signature.pop("table") # both should have same number of arguments @@ -496,96 +470,97 @@ def test_default_argument_consistency(self): == batcher_init_signature[arg_name].default ) - @CrossSync.pytest - @pytest.mark.parametrize("input_val", [None, 0, -1]) - async def test__start_flush_timer_w_empty_input(self, input_val): - """Empty/invalid timer should return immediately""" - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: - # mock different method depending on sync vs async - async with self._make_one() as instance: - if CrossSync.is_async: - sleep_obj, sleep_method = asyncio, "wait_for" - else: - sleep_obj, sleep_method = instance._closed, "wait" - with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: - result = await instance._timer_routine(input_val) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 - assert result is None + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" + ) + @pytest.mark.asyncio + async def test__start_flush_timer_w_None(self, flush_mock): + """Empty timer should return immediately""" + async with self._make_one() as instance: + with mock.patch("asyncio.sleep") as sleep_mock: + await instance._start_flush_timer(None) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 - @CrossSync.pytest - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - async def test__start_flush_timer_call_when_closed( - self, - ): + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" + ) + @pytest.mark.asyncio + async def test__start_flush_timer_call_when_closed(self, flush_mock): """closed batcher's timer should return immediately""" - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: - async with self._make_one() as instance: - await instance.close() - flush_mock.reset_mock() - # mock different method depending on sync vs async - if CrossSync.is_async: - sleep_obj, sleep_method = asyncio, "wait_for" - else: - sleep_obj, sleep_method = instance._closed, "wait" - with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: - await instance._timer_routine(10) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 + async with self._make_one() as instance: + await instance.close() + flush_mock.reset_mock() + with mock.patch("asyncio.sleep") as sleep_mock: + await instance._start_flush_timer(1) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 - @CrossSync.pytest - @pytest.mark.parametrize("num_staged", [0, 1, 10]) - @pytest.mark.filterwarnings("ignore::RuntimeWarning") - async def test__flush_timer(self, num_staged): + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" + ) + @pytest.mark.asyncio + async def test__flush_timer(self, flush_mock): """Timer should continue to 
call _schedule_flush in a loop""" - from google.cloud.bigtable.data._sync.cross_sync import CrossSync + expected_sleep = 12 + async with self._make_one(flush_interval=expected_sleep) as instance: + instance._staged_entries = [mock.Mock()] + loop_num = 3 + with mock.patch("asyncio.sleep") as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [asyncio.CancelledError()] + try: + await instance._flush_timer + except asyncio.CancelledError: + pass + assert sleep_mock.call_count == loop_num + 1 + sleep_mock.assert_called_with(expected_sleep) + assert flush_mock.call_count == loop_num + + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" + ) + @pytest.mark.asyncio + async def test__flush_timer_no_mutations(self, flush_mock): + """Timer should not flush if no new mutations have been staged""" + expected_sleep = 12 + async with self._make_one(flush_interval=expected_sleep) as instance: + loop_num = 3 + with mock.patch("asyncio.sleep") as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [asyncio.CancelledError()] + try: + await instance._flush_timer + except asyncio.CancelledError: + pass + assert sleep_mock.call_count == loop_num + 1 + sleep_mock.assert_called_with(expected_sleep) + assert flush_mock.call_count == 0 - with mock.patch.object( - self._get_target_class(), "_schedule_flush" - ) as flush_mock: - expected_sleep = 12 - async with self._make_one(flush_interval=expected_sleep) as instance: - loop_num = 3 - instance._staged_entries = [mock.Mock()] * num_staged - with mock.patch.object(CrossSync, "event_wait") as sleep_mock: - sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] - with pytest.raises(TabError): - await self._get_target_class()._timer_routine( - instance, expected_sleep - ) - if CrossSync.is_async: - # replace with np-op so there are no issues on close - instance._flush_timer = CrossSync.Future() - assert sleep_mock.call_count == loop_num + 1 - sleep_kwargs = sleep_mock.call_args[1] - assert sleep_kwargs["timeout"] == expected_sleep - assert flush_mock.call_count == (0 if num_staged == 0 else loop_num) - - @CrossSync.pytest - async def test__flush_timer_close(self): + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" + ) + @pytest.mark.asyncio + async def test__flush_timer_close(self, flush_mock): """Timer should continue terminate after close""" - with mock.patch.object(self._get_target_class(), "_schedule_flush"): - async with self._make_one() as instance: + async with self._make_one() as instance: + with mock.patch("asyncio.sleep"): # let task run in background + await asyncio.sleep(0.5) assert instance._flush_timer.done() is False # close the batcher await instance.close() + await asyncio.sleep(0.1) # task should be complete assert instance._flush_timer.done() is True - @CrossSync.pytest + @pytest.mark.asyncio async def test_append_closed(self): """Should raise exception""" - instance = self._make_one() - await instance.close() with pytest.raises(RuntimeError): + instance = self._make_one() + await instance.close() await instance.append(mock.Mock()) - @CrossSync.pytest + @pytest.mark.asyncio async def test_append_wrong_mutation(self): """ Mutation objects should raise an exception. 
@@ -599,13 +574,13 @@ async def test_append_wrong_mutation(self): await instance.append(DeleteAllFromRow()) assert str(e.value) == expected_error - @CrossSync.pytest + @pytest.mark.asyncio async def test_append_outside_flow_limits(self): """entries larger than mutation limits are still processed""" async with self._make_one( flow_control_max_mutation_count=1, flow_control_max_bytes=1 ) as instance: - oversized_entry = self._make_mutation(count=0, size=2) + oversized_entry = _make_mutation(count=0, size=2) await instance.append(oversized_entry) assert instance._staged_entries == [oversized_entry] assert instance._staged_count == 0 @@ -614,21 +589,25 @@ async def test_append_outside_flow_limits(self): async with self._make_one( flow_control_max_mutation_count=1, flow_control_max_bytes=1 ) as instance: - overcount_entry = self._make_mutation(count=2, size=0) + overcount_entry = _make_mutation(count=2, size=0) await instance.append(overcount_entry) assert instance._staged_entries == [overcount_entry] assert instance._staged_count == 2 assert instance._staged_bytes == 0 instance._staged_entries = [] - @CrossSync.pytest + @pytest.mark.asyncio async def test_append_flush_runs_after_limit_hit(self): """ If the user appends a bunch of entries above the flush limits back-to-back, it should still flush in a single task """ + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) + with mock.patch.object( - self._get_target_class(), "_execute_mutate_rows" + MutationsBatcherAsync, "_execute_mutate_rows" ) as op_mock: async with self._make_one(flush_limit_bytes=100) as instance: # mock network calls @@ -637,13 +616,13 @@ async def mock_call(*args, **kwargs): op_mock.side_effect = mock_call # append a mutation just under the size limit - await instance.append(self._make_mutation(size=99)) + await instance.append(_make_mutation(size=99)) # append a bunch of entries back-to-back in a loop num_entries = 10 for _ in range(num_entries): - await instance.append(self._make_mutation(size=1)) + await instance.append(_make_mutation(size=1)) # let any flush jobs finish - await instance._wait_for_batch_results(*instance._flush_jobs) + await asyncio.gather(*instance._flush_jobs) # should have only flushed once, with large mutation and first mutation in loop assert op_mock.call_count == 1 sent_batch = op_mock.call_args[0][0] @@ -663,8 +642,7 @@ async def mock_call(*args, **kwargs): (1, 1, 0, 0, False), ], ) - @CrossSync.pytest - @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.asyncio async def test_append( self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush ): @@ -675,7 +653,7 @@ async def test_append( assert instance._staged_count == 0 assert instance._staged_bytes == 0 assert instance._staged_entries == [] - mutation = self._make_mutation(count=mutation_count, size=mutation_bytes) + mutation = _make_mutation(count=mutation_count, size=mutation_bytes) with mock.patch.object(instance, "_schedule_flush") as flush_mock: await instance.append(mutation) assert flush_mock.call_count == bool(expect_flush) @@ -684,7 +662,7 @@ async def test_append( assert instance._staged_entries == [mutation] instance._staged_entries = [] - @CrossSync.pytest + @pytest.mark.asyncio async def test_append_multiple_sequentially(self): """Append multiple mutations""" async with self._make_one( @@ -693,7 +671,7 @@ async def test_append_multiple_sequentially(self): assert instance._staged_count == 0 assert instance._staged_bytes == 0 assert 
instance._staged_entries == [] - mutation = self._make_mutation(count=2, size=3) + mutation = _make_mutation(count=2, size=3) with mock.patch.object(instance, "_schedule_flush") as flush_mock: await instance.append(mutation) assert flush_mock.call_count == 0 @@ -712,7 +690,7 @@ async def test_append_multiple_sequentially(self): assert len(instance._staged_entries) == 3 instance._staged_entries = [] - @CrossSync.pytest + @pytest.mark.asyncio async def test_flush_flow_control_concurrent_requests(self): """ requests should happen in parallel if flow control breaks up single flush into batches @@ -720,14 +698,14 @@ async def test_flush_flow_control_concurrent_requests(self): import time num_calls = 10 - fake_mutations = [self._make_mutation(count=1) for _ in range(num_calls)] + fake_mutations = [_make_mutation(count=1) for _ in range(num_calls)] async with self._make_one(flow_control_max_mutation_count=1) as instance: with mock.patch.object( - instance, "_execute_mutate_rows", CrossSync.Mock() + instance, "_execute_mutate_rows", AsyncMock() ) as op_mock: # mock network calls async def mock_call(*args, **kwargs): - await CrossSync.sleep(0.1) + await asyncio.sleep(0.1) return [] op_mock.side_effect = mock_call @@ -735,15 +713,15 @@ async def mock_call(*args, **kwargs): # flush one large batch, that will be broken up into smaller batches instance._staged_entries = fake_mutations instance._schedule_flush() - await CrossSync.sleep(0.01) + await asyncio.sleep(0.01) # make room for new mutations for i in range(num_calls): await instance._flow_control.remove_from_flow( - [self._make_mutation(count=1)] + [_make_mutation(count=1)] ) - await CrossSync.sleep(0.01) + await asyncio.sleep(0.01) # allow flushes to complete - await instance._wait_for_batch_results(*instance._flush_jobs) + await asyncio.gather(*instance._flush_jobs) duration = time.monotonic() - start_time assert len(instance._oldest_exceptions) == 0 assert len(instance._newest_exceptions) == 0 @@ -751,7 +729,7 @@ async def mock_call(*args, **kwargs): assert duration < 0.5 assert op_mock.call_count == num_calls - @CrossSync.pytest + @pytest.mark.asyncio async def test_schedule_flush_no_mutations(self): """schedule flush should return None if no staged mutations""" async with self._make_one() as instance: @@ -760,15 +738,11 @@ async def test_schedule_flush_no_mutations(self): assert instance._schedule_flush() is None assert flush_mock.call_count == 0 - @CrossSync.pytest - @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.asyncio async def test_schedule_flush_with_mutations(self): """if new mutations exist, should add a new flush task to _flush_jobs""" async with self._make_one() as instance: with mock.patch.object(instance, "_flush_internal") as flush_mock: - if not CrossSync.is_async: - # simulate operation - flush_mock.side_effect = lambda x: time.sleep(0.1) for i in range(1, 4): mutation = mock.Mock() instance._staged_entries = [mutation] @@ -779,10 +753,9 @@ async def test_schedule_flush_with_mutations(self): assert instance._staged_entries == [] assert instance._staged_count == 0 assert instance._staged_bytes == 0 - assert flush_mock.call_count == 1 - flush_mock.reset_mock() + assert flush_mock.call_count == i - @CrossSync.pytest + @pytest.mark.asyncio async def test__flush_internal(self): """ _flush_internal should: @@ -802,7 +775,7 @@ async def gen(x): yield x flow_mock.side_effect = lambda x: gen(x) - mutations = [self._make_mutation(count=1, size=1)] * num_entries + mutations = [_make_mutation(count=1, size=1)] * 
num_entries await instance._flush_internal(mutations) assert instance._entries_processed_since_last_raise == num_entries assert execute_mock.call_count == 1 @@ -810,28 +783,20 @@ async def gen(x): instance._oldest_exceptions.clear() instance._newest_exceptions.clear() - @CrossSync.pytest + @pytest.mark.asyncio async def test_flush_clears_job_list(self): """ a job should be added to _flush_jobs when _schedule_flush is called, and removed when it completes """ async with self._make_one() as instance: - with mock.patch.object( - instance, "_flush_internal", CrossSync.Mock() - ) as flush_mock: - if not CrossSync.is_async: - # simulate operation - flush_mock.side_effect = lambda x: time.sleep(0.1) - mutations = [self._make_mutation(count=1, size=1)] + with mock.patch.object(instance, "_flush_internal", AsyncMock()): + mutations = [_make_mutation(count=1, size=1)] instance._staged_entries = mutations assert instance._flush_jobs == set() new_job = instance._schedule_flush() assert instance._flush_jobs == {new_job} - if CrossSync.is_async: - await new_job - else: - new_job.result() + await new_job assert instance._flush_jobs == set() @pytest.mark.parametrize( @@ -846,7 +811,7 @@ async def test_flush_clears_job_list(self): (10, 20, 20), # should cap at 20 ], ) - @CrossSync.pytest + @pytest.mark.asyncio async def test__flush_internal_with_errors( self, num_starting, num_new_errors, expected_total_errors ): @@ -871,7 +836,7 @@ async def gen(x): yield x flow_mock.side_effect = lambda x: gen(x) - mutations = [self._make_mutation(count=1, size=1)] * num_entries + mutations = [_make_mutation(count=1, size=1)] * num_entries await instance._flush_internal(mutations) assert instance._entries_processed_since_last_raise == num_entries assert execute_mock.call_count == 1 @@ -888,7 +853,6 @@ async def gen(x): instance._oldest_exceptions.clear() instance._newest_exceptions.clear() - @CrossSync.convert async def _mock_gapic_return(self, num=5): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -902,11 +866,11 @@ async def gen(num): return gen(num) - @CrossSync.pytest + @pytest.mark.asyncio async def test_timer_flush_end_to_end(self): """Flush should automatically trigger after flush_interval""" - num_mutations = 10 - mutations = [self._make_mutation(count=2, size=2)] * num_mutations + num_nutations = 10 + mutations = [_make_mutation(count=2, size=2)] * num_nutations async with self._make_one(flush_interval=0.05) as instance: instance._table.default_operation_timeout = 10 @@ -915,65 +879,69 @@ async def test_timer_flush_end_to_end(self): instance._table.client._gapic_client, "mutate_rows" ) as gapic_mock: gapic_mock.side_effect = ( - lambda *args, **kwargs: self._mock_gapic_return(num_mutations) + lambda *args, **kwargs: self._mock_gapic_return(num_nutations) ) for m in mutations: await instance.append(m) assert instance._entries_processed_since_last_raise == 0 # let flush trigger due to timer - await CrossSync.sleep(0.1) - assert instance._entries_processed_since_last_raise == num_mutations - - @CrossSync.pytest - async def test__execute_mutate_rows(self): - with mock.patch.object(CrossSync, "_MutateRowsOperation") as mutate_rows: - mutate_rows.return_value = CrossSync.Mock() - start_operation = mutate_rows().start - table = mock.Mock() - table.table_name = "test-table" - table.app_profile_id = "test-app-profile" - table.default_mutate_rows_operation_timeout = 17 - table.default_mutate_rows_attempt_timeout = 13 - table.default_mutate_rows_retryable_errors = () - 
async with self._make_one(table) as instance: - batch = [self._make_mutation()] - result = await instance._execute_mutate_rows(batch) - assert start_operation.call_count == 1 - args, kwargs = mutate_rows.call_args - assert args[0] == table.client._gapic_client - assert args[1] == table - assert args[2] == batch - kwargs["operation_timeout"] == 17 - kwargs["attempt_timeout"] == 13 - assert result == [] - - @CrossSync.pytest - async def test__execute_mutate_rows_returns_errors(self): + await asyncio.sleep(0.1) + assert instance._entries_processed_since_last_raise == num_nutations + + @pytest.mark.asyncio + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync", + ) + async def test__execute_mutate_rows(self, mutate_rows): + mutate_rows.return_value = AsyncMock() + start_operation = mutate_rows().start + table = mock.Mock() + table.table_name = "test-table" + table.app_profile_id = "test-app-profile" + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + async with self._make_one(table) as instance: + batch = [_make_mutation()] + result = await instance._execute_mutate_rows(batch) + assert start_operation.call_count == 1 + args, kwargs = mutate_rows.call_args + assert args[0] == table.client._gapic_client + assert args[1] == table + assert args[2] == batch + kwargs["operation_timeout"] == 17 + kwargs["attempt_timeout"] == 13 + assert result == [] + + @pytest.mark.asyncio + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync.start" + ) + async def test__execute_mutate_rows_returns_errors(self, mutate_rows): """Errors from operation should be returned as list""" from google.cloud.bigtable.data.exceptions import ( MutationsExceptionGroup, FailedMutationEntryError, ) - with mock.patch.object(CrossSync._MutateRowsOperation, "start") as mutate_rows: - err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) - err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) - mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 17 - table.default_mutate_rows_attempt_timeout = 13 - table.default_mutate_rows_retryable_errors = () - async with self._make_one(table) as instance: - batch = [self._make_mutation()] - result = await instance._execute_mutate_rows(batch) - assert len(result) == 2 - assert result[0] == err1 - assert result[1] == err2 - # indices should be set to None - assert result[0].index is None - assert result[1].index is None - - @CrossSync.pytest + err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) + err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) + mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + async with self._make_one(table) as instance: + batch = [_make_mutation()] + result = await instance._execute_mutate_rows(batch) + assert len(result) == 2 + assert result[0] == err1 + assert result[1] == err2 + # indices should be set to None + assert result[0].index is None + assert result[1].index is None + + @pytest.mark.asyncio async def test__raise_exceptions(self): """Raise exceptions and reset error state""" from google.cloud.bigtable.data import exceptions
@@ -993,19 +961,13 @@ async def test__raise_exceptions(self): # try calling again instance._raise_exceptions() - @CrossSync.pytest - @CrossSync.convert( - sync_name="test___enter__", replace_symbols={"__aenter__": "__enter__"} - ) + @pytest.mark.asyncio async def test___aenter__(self): """Should return self""" async with self._make_one() as instance: assert await instance.__aenter__() == instance - @CrossSync.pytest - @CrossSync.convert( - sync_name="test___exit__", replace_symbols={"__aexit__": "__exit__"} - ) + @pytest.mark.asyncio async def test___aexit__(self): """aexit should call close""" async with self._make_one() as instance: @@ -1013,7 +975,7 @@ async def test___aexit__(self): await instance.__aexit__(None, None, None) assert close_mock.call_count == 1 - @CrossSync.pytest + @pytest.mark.asyncio async def test_close(self): """Should clean up all resources""" async with self._make_one() as instance: @@ -1026,7 +988,7 @@ async def test_close(self): assert flush_mock.call_count == 1 assert raise_mock.call_count == 1 - @CrossSync.pytest + @pytest.mark.asyncio async def test_close_w_exceptions(self): """Raise exceptions on close""" from google.cloud.bigtable.data import exceptions @@ -1045,7 +1007,7 @@ async def test_close_w_exceptions(self): # clear out exceptions instance._oldest_exceptions, instance._newest_exceptions = ([], []) - @CrossSync.pytest + @pytest.mark.asyncio async def test__on_exit(self, recwarn): """Should raise warnings if unflushed mutations exist""" async with self._make_one() as instance: @@ -1061,13 +1023,13 @@ async def test__on_exit(self, recwarn): assert "unflushed mutations" in str(w[0].message).lower() assert str(num_left) in str(w[0].message) # calling while closed is noop - instance._closed.set() + instance.closed = True instance._on_exit() assert len(recwarn) == 0 # reset staged mutations for cleanup instance._staged_entries = [] - @CrossSync.pytest + @pytest.mark.asyncio async def test_atexit_registration(self): """Should run _on_exit on program termination""" import atexit @@ -1077,29 +1039,30 @@ async def test_atexit_registration(self): async with self._make_one(): assert register_mock.call_count == 1 - @CrossSync.pytest - async def test_timeout_args_passed(self): + @pytest.mark.asyncio + @mock.patch( + "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync", + ) + async def test_timeout_args_passed(self, mutate_rows): """ batch_operation_timeout and batch_attempt_timeout should be used in api calls """ - with mock.patch.object( - CrossSync, "_MutateRowsOperation", return_value=CrossSync.Mock() - ) as mutate_rows: - expected_operation_timeout = 17 - expected_attempt_timeout = 13 - async with self._make_one( - batch_operation_timeout=expected_operation_timeout, - batch_attempt_timeout=expected_attempt_timeout, - ) as instance: - assert instance._operation_timeout == expected_operation_timeout - assert instance._attempt_timeout == expected_attempt_timeout - # make simulated gapic call - await instance._execute_mutate_rows([self._make_mutation()]) - assert mutate_rows.call_count == 1 - kwargs = mutate_rows.call_args[1] - assert kwargs["operation_timeout"] == expected_operation_timeout - assert kwargs["attempt_timeout"] == expected_attempt_timeout + mutate_rows.return_value = AsyncMock() + expected_operation_timeout = 17 + expected_attempt_timeout = 13 + async with self._make_one( + batch_operation_timeout=expected_operation_timeout, + batch_attempt_timeout=expected_attempt_timeout, + ) as instance: + assert 
instance._operation_timeout == expected_operation_timeout + assert instance._attempt_timeout == expected_attempt_timeout + # make simulated gapic call + await instance._execute_mutate_rows([_make_mutation()]) + assert mutate_rows.call_count == 1 + kwargs = mutate_rows.call_args[1] + assert kwargs["operation_timeout"] == expected_operation_timeout + assert kwargs["attempt_timeout"] == expected_attempt_timeout @pytest.mark.parametrize( "limit,in_e,start_e,end_e", @@ -1160,7 +1123,7 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): for i in range(1, newest_list_diff + 1): assert mock_batcher._newest_exceptions[-i] == input_list[-i] - @CrossSync.pytest + @pytest.mark.asyncio # test different inputs for retryable exceptions @pytest.mark.parametrize( "input_retryables,expected_retryables", @@ -1185,7 +1148,6 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): ([4], [core_exceptions.DeadlineExceeded]), ], ) - @CrossSync.convert async def test_customizable_retryable_errors( self, input_retryables, expected_retryables ): @@ -1193,21 +1155,25 @@ async def test_customizable_retryable_errors( Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. """ - with mock.patch.object( - google.api_core.retry, "if_exception_type" + from google.cloud.bigtable.data._async.client import TableAsync + + with mock.patch( + "google.api_core.retry.if_exception_type" ) as predicate_builder_mock: - with mock.patch.object(CrossSync, "retry_target") as retry_fn_mock: + with mock.patch( + "google.api_core.retry.retry_target_async" + ) as retry_fn_mock: table = None with mock.patch("asyncio.create_task"): - table = CrossSync.Table(mock.Mock(), "instance", "table") + table = TableAsync(mock.Mock(), "instance", "table") async with self._make_one( table, batch_retryable_errors=input_retryables ) as instance: assert instance._retryable_errors == expected_retryables - expected_predicate = expected_retryables.__contains__ + expected_predicate = lambda a: a in expected_retryables # noqa predicate_builder_mock.return_value = expected_predicate retry_fn_mock.side_effect = RuntimeError("stop early") - mutation = self._make_mutation(count=1, size=1) + mutation = _make_mutation(count=1, size=1) await instance._execute_mutate_rows([mutation]) # passed in errors should be used to build the predicate predicate_builder_mock.assert_called_once_with( diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py deleted file mode 100644 index b30f7544f..000000000 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ /dev/null @@ -1,351 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from __future__ import annotations - -import os -import warnings -import pytest -import mock - -from itertools import zip_longest - -from google.cloud.bigtable_v2 import ReadRowsResponse - -from google.cloud.bigtable.data.exceptions import InvalidChunk -from google.cloud.bigtable.data.row import Row - -from ...v2_client.test_row_merger import ReadRowsTest, TestFile - -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - - -@CrossSync.export_sync( - path="tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", -) -class TestReadRowsAcceptanceAsync: - @staticmethod - @CrossSync.convert - def _get_operation_class(): - return CrossSync._ReadRowsOperation - - @staticmethod - @CrossSync.convert - def _get_client_class(): - return CrossSync.DataClient - - def parse_readrows_acceptance_tests(): - dirname = os.path.dirname(__file__) - filename = os.path.join(dirname, "../read-rows-acceptance-test.json") - - with open(filename) as json_file: - test_json = TestFile.from_json(json_file.read()) - return test_json.read_rows_tests - - @staticmethod - def extract_results_from_row(row: Row): - results = [] - for family, col, cells in row.items(): - for cell in cells: - results.append( - ReadRowsTest.Result( - row_key=row.row_key, - family_name=family, - qualifier=col, - timestamp_micros=cell.timestamp_ns // 1000, - value=cell.value, - label=(cell.labels[0] if cell.labels else ""), - ) - ) - return results - - @staticmethod - @CrossSync.convert - async def _coro_wrapper(stream): - return stream - - @CrossSync.convert - async def _process_chunks(self, *chunks): - async def _row_stream(): - yield ReadRowsResponse(chunks=chunks) - - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = None - chunker = self._get_operation_class().chunk_stream( - instance, self._coro_wrapper(_row_stream()) - ) - merger = self._get_operation_class().merge_rows(chunker) - results = [] - async for row in merger: - results.append(row) - return results - - @pytest.mark.parametrize( - "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description - ) - @CrossSync.pytest - async def test_row_merger_scenario(self, test_case: ReadRowsTest): - async def _scenerio_stream(): - for chunk in test_case.chunks: - yield ReadRowsResponse(chunks=[chunk]) - - try: - results = [] - instance = mock.Mock() - instance._last_yielded_row_key = None - instance._remaining_count = None - chunker = self._get_operation_class().chunk_stream( - instance, self._coro_wrapper(_scenerio_stream()) - ) - merger = self._get_operation_class().merge_rows(chunker) - async for row in merger: - for cell in row: - cell_result = ReadRowsTest.Result( - row_key=cell.row_key, - family_name=cell.family, - qualifier=cell.qualifier, - timestamp_micros=cell.timestamp_micros, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - @pytest.mark.parametrize( - "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description - ) - @CrossSync.pytest - async def test_read_rows_scenario(self, test_case: ReadRowsTest): - async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): - from google.cloud.bigtable_v2 import ReadRowsResponse - - class mock_stream: - def __init__(self, chunk_list): - self.chunk_list = chunk_list - self.idx = -1 - - def __aiter__(self): - return self 
- - def __iter__(self): - return self - - async def __anext__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - chunk = self.chunk_list[self.idx] - return ReadRowsResponse(chunks=[chunk]) - raise CrossSync.StopIteration - - def __next__(self): - return self.__anext__() - - def cancel(self): - pass - - return mock_stream(chunk_list) - - with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # use emulator mode to avoid auth issues in CI - client = self._get_client_class()() - try: - table = client.get_table("instance", "table") - results = [] - with mock.patch.object( - table.client._gapic_client, "read_rows" - ) as read_rows: - # run once, then return error on retry - read_rows.return_value = _make_gapic_stream(test_case.chunks) - async for row in await table.read_rows_stream(query={}): - for cell in row: - cell_result = ReadRowsTest.Result( - row_key=cell.row_key, - family_name=cell.family, - qualifier=cell.qualifier, - timestamp_micros=cell.timestamp_micros, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - finally: - await client.close() - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - @CrossSync.pytest - async def test_out_of_order_rows(self): - async def _row_stream(): - yield ReadRowsResponse(last_scanned_row_key=b"a") - - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = b"b" - chunker = self._get_operation_class().chunk_stream( - instance, self._coro_wrapper(_row_stream()) - ) - merger = self._get_operation_class().merge_rows(chunker) - with pytest.raises(InvalidChunk): - async for _ in merger: - pass - - @CrossSync.pytest - async def test_bare_reset(self): - first_chunk = ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk( - row_key=b"a", family_name="f", qualifier=b"q", value=b"v" - ) - ) - with pytest.raises(InvalidChunk): - await self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") - ), - ) - with pytest.raises(InvalidChunk): - await self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, family_name="f") - ), - ) - with pytest.raises(InvalidChunk): - await self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") - ), - ) - with pytest.raises(InvalidChunk): - await self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) - ), - ) - with pytest.raises(InvalidChunk): - await self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) - ), - ) - with pytest.raises(InvalidChunk): - await self._process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, value=b"v") - ), - ) - - @CrossSync.pytest - async def test_missing_family(self): - with pytest.raises(InvalidChunk): - await self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - qualifier=b"q", - timestamp_micros=1000, - value=b"v", - commit_row=True, - ) - ) - - @CrossSync.pytest - async def test_mid_cell_row_key_change(self): - with pytest.raises(InvalidChunk): - await self._process_chunks( - 
ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), - ) - - @CrossSync.pytest - async def test_mid_cell_family_change(self): - with pytest.raises(InvalidChunk): - await self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - family_name="f2", value=b"v", commit_row=True - ), - ) - - @CrossSync.pytest - async def test_mid_cell_qualifier_change(self): - with pytest.raises(InvalidChunk): - await self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - qualifier=b"q2", value=b"v", commit_row=True - ), - ) - - @CrossSync.pytest - async def test_mid_cell_timestamp_change(self): - with pytest.raises(InvalidChunk): - await self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - timestamp_micros=2000, value=b"v", commit_row=True - ), - ) - - @CrossSync.pytest - async def test_mid_cell_labels_change(self): - with pytest.raises(InvalidChunk): - await self._process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), - ) diff --git a/tests/unit/data/test_read_rows_acceptance.py b/tests/unit/data/test_read_rows_acceptance.py new file mode 100644 index 000000000..7cb3c08dc --- /dev/null +++ b/tests/unit/data/test_read_rows_acceptance.py @@ -0,0 +1,331 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import os +from itertools import zip_longest + +import pytest +import mock + +from google.cloud.bigtable_v2 import ReadRowsResponse + +from google.cloud.bigtable.data._async.client import BigtableDataClientAsync +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync +from google.cloud.bigtable.data.row import Row + +from ..v2_client.test_row_merger import ReadRowsTest, TestFile + + +def parse_readrows_acceptance_tests(): + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "./read-rows-acceptance-test.json") + + with open(filename) as json_file: + test_json = TestFile.from_json(json_file.read()) + return test_json.read_rows_tests + + +def extract_results_from_row(row: Row): + results = [] + for family, col, cells in row.items(): + for cell in cells: + results.append( + ReadRowsTest.Result( + row_key=row.row_key, + family_name=family, + qualifier=col, + timestamp_micros=cell.timestamp_ns // 1000, + value=cell.value, + label=(cell.labels[0] if cell.labels else ""), + ) + ) + return results + + +@pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description +) +@pytest.mark.asyncio +async def test_row_merger_scenario(test_case: ReadRowsTest): + async def _scenerio_stream(): + for chunk in test_case.chunks: + yield ReadRowsResponse(chunks=[chunk]) + + try: + results = [] + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + chunker = _ReadRowsOperationAsync.chunk_stream( + instance, _coro_wrapper(_scenerio_stream()) + ) + merger = _ReadRowsOperationAsync.merge_rows(chunker) + async for row in merger: + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + +@pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description +) +@pytest.mark.asyncio +async def test_read_rows_scenario(test_case: ReadRowsTest): + async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list): + self.chunk_list = chunk_list + self.idx = -1 + + def __aiter__(self): + return self + + async def __anext__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + chunk = self.chunk_list[self.idx] + return ReadRowsResponse(chunks=[chunk]) + raise StopAsyncIteration + + def cancel(self): + pass + + return mock_stream(chunk_list) + + try: + with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): + # use emulator mode to avoid auth issues in CI + client = BigtableDataClientAsync() + table = client.get_table("instance", "table") + results = [] + with mock.patch.object(table.client._gapic_client, "read_rows") as read_rows: + # run once, then return error on retry + read_rows.return_value = _make_gapic_stream(test_case.chunks) + async for row in await table.read_rows_stream(query={}): + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + 
timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + finally: + await client.close() + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + +@pytest.mark.asyncio +async def test_out_of_order_rows(): + async def _row_stream(): + yield ReadRowsResponse(last_scanned_row_key=b"a") + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = b"b" + chunker = _ReadRowsOperationAsync.chunk_stream( + instance, _coro_wrapper(_row_stream()) + ) + merger = _ReadRowsOperationAsync.merge_rows(chunker) + with pytest.raises(InvalidChunk): + async for _ in merger: + pass + + +@pytest.mark.asyncio +async def test_bare_reset(): + first_chunk = ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk( + row_key=b"a", family_name="f", qualifier=b"q", value=b"v" + ) + ) + with pytest.raises(InvalidChunk): + await _process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") + ), + ) + with pytest.raises(InvalidChunk): + await _process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, family_name="f") + ), + ) + with pytest.raises(InvalidChunk): + await _process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") + ), + ) + with pytest.raises(InvalidChunk): + await _process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) + ), + ) + with pytest.raises(InvalidChunk): + await _process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) + ), + ) + with pytest.raises(InvalidChunk): + await _process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, value=b"v") + ), + ) + + +@pytest.mark.asyncio +async def test_missing_family(): + with pytest.raises(InvalidChunk): + await _process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + qualifier=b"q", + timestamp_micros=1000, + value=b"v", + commit_row=True, + ) + ) + + +@pytest.mark.asyncio +async def test_mid_cell_row_key_change(): + with pytest.raises(InvalidChunk): + await _process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), + ) + + +@pytest.mark.asyncio +async def test_mid_cell_family_change(): + with pytest.raises(InvalidChunk): + await _process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(family_name="f2", value=b"v", commit_row=True), + ) + + +@pytest.mark.asyncio +async def test_mid_cell_qualifier_change(): + with pytest.raises(InvalidChunk): + await _process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(qualifier=b"q2", value=b"v", commit_row=True), + ) + + +@pytest.mark.asyncio +async def test_mid_cell_timestamp_change(): + with pytest.raises(InvalidChunk): + await _process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + 
family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + timestamp_micros=2000, value=b"v", commit_row=True + ), + ) + + +@pytest.mark.asyncio +async def test_mid_cell_labels_change(): + with pytest.raises(InvalidChunk): + await _process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), + ) + + +async def _coro_wrapper(stream): + return stream + + +async def _process_chunks(*chunks): + async def _row_stream(): + yield ReadRowsResponse(chunks=chunks) + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = None + chunker = _ReadRowsOperationAsync.chunk_stream( + instance, _coro_wrapper(_row_stream()) + ) + merger = _ReadRowsOperationAsync.merge_rows(chunker) + results = [] + async for row in merger: + results.append(row) + return results From af020a2c59d70656d98cca65bda660806eb7a6f9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 17 Jul 2024 17:49:15 -0700 Subject: [PATCH 193/360] Revert "reverted client and test changes" This reverts commit 37b4833736a5e6fc8d2baf96312f694602a922a7. --- google/cloud/bigtable/data/__init__.py | 2 +- .../bigtable/data/_async/_mutate_rows.py | 49 +- .../cloud/bigtable/data/_async/_read_rows.py | 53 +- google/cloud/bigtable/data/_async/client.py | 374 ++++--- .../bigtable/data/_async/mutations_batcher.py | 212 ++-- google/cloud/bigtable/data/_helpers.py | 3 + google/cloud/bigtable/data/exceptions.py | 15 + google/cloud/bigtable/data/mutations.py | 12 + .../transports/pooled_grpc_asyncio.py | 27 +- tests/system/data/__init__.py | 3 + tests/system/data/setup_fixtures.py | 25 - tests/system/data/test_system.py | 942 ----------------- tests/system/data/test_system_async.py | 992 ++++++++++++++++++ tests/unit/data/_async/__init__.py | 0 tests/unit/data/_async/test__mutate_rows.py | 110 +- tests/unit/data/_async/test__read_rows.py | 75 +- tests/unit/data/_async/test_client.py | 900 +++++++++------- .../data/_async/test_mutations_batcher.py | 782 +++++++------- .../data/_async/test_read_rows_acceptance.py | 351 +++++++ tests/unit/data/test_read_rows_acceptance.py | 331 ------ 20 files changed, 2813 insertions(+), 2445 deletions(-) delete mode 100644 tests/system/data/test_system.py create mode 100644 tests/system/data/test_system_async.py create mode 100644 tests/unit/data/_async/__init__.py create mode 100644 tests/unit/data/_async/test_read_rows_acceptance.py delete mode 100644 tests/unit/data/test_read_rows_acceptance.py diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index 5229f8021..66fe3479b 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -50,10 +50,10 @@ __all__ = ( "BigtableDataClientAsync", "TableAsync", + "MutationsBatcherAsync", "RowKeySamples", "ReadRowsQuery", "RowRange", - "MutationsBatcherAsync", "Mutation", "RowMutationEntry", "SetCell", diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 99b9944cd..e62d43397 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -15,12 +15,10 @@ from __future__ import annotations from typing import Sequence, TYPE_CHECKING -from dataclasses import dataclass import functools from google.api_core import 
exceptions as core_exceptions from google.api_core import retry as retries -import google.cloud.bigtable_v2.types.bigtable as types_pb import google.cloud.bigtable.data.exceptions as bt_exceptions from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _attempt_timeout_generator @@ -28,25 +26,25 @@ # mutate_rows requests are limited to this number of mutations from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable.data.mutations import _EntryWithProto + +from google.cloud.bigtable.data._sync.cross_sync import CrossSync if TYPE_CHECKING: - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) from google.cloud.bigtable.data.mutations import RowMutationEntry - from google.cloud.bigtable.data._async.client import TableAsync - -@dataclass -class _EntryWithProto: - """ - A dataclass to hold a RowMutationEntry and its corresponding proto representation. - """ + if CrossSync.is_async: + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) - entry: RowMutationEntry - proto: types_pb.MutateRowsRequest.Entry + CrossSync.add_mapping("GapicClient", BigtableAsyncClient) +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", + add_mapping_for_name="_MutateRowsOperation", +) class _MutateRowsOperationAsync: """ MutateRowsOperation manages the logic of sending a set of row mutations, @@ -66,10 +64,11 @@ class _MutateRowsOperationAsync: If not specified, the request will run until operation_timeout is reached. """ + @CrossSync.convert def __init__( self, - gapic_client: "BigtableAsyncClient", - table: "TableAsync", + gapic_client: "CrossSync.GapicClient", + table: "CrossSync.Table", mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, @@ -100,7 +99,7 @@ def __init__( bt_exceptions._MutateRowsIncomplete, ) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - self._operation = retries.retry_target_async( + self._operation = lambda: CrossSync.retry_target( self._run_attempt, self.is_retryable, sleep_generator, @@ -115,6 +114,7 @@ def __init__( self.remaining_indices = list(range(len(self.mutations))) self.errors: dict[int, list[Exception]] = {} + @CrossSync.convert async def start(self): """ Start the operation, and run until completion @@ -124,7 +124,7 @@ async def start(self): """ try: # trigger mutate_rows - await self._operation + CrossSync.rm_aio(await self._operation()) except Exception as exc: # exceptions raised by retryable are added to the list of exceptions for all unfinalized mutations incomplete_indices = self.remaining_indices.copy() @@ -151,6 +151,7 @@ async def start(self): all_errors, len(self.mutations) ) + @CrossSync.convert async def _run_attempt(self): """ Run a single attempt of the mutate_rows rpc. 
@@ -171,12 +172,14 @@ async def _run_attempt(self): return # make gapic request try: - result_generator = await self._gapic_fn( - timeout=next(self.timeout_generator), - entries=request_entries, - retry=None, + result_generator = CrossSync.rm_aio( + await self._gapic_fn( + timeout=next(self.timeout_generator), + entries=request_entries, + retry=None, + ) ) - async for result_list in result_generator: + async for result_list in CrossSync.rm_aio(result_generator): for result in result_list.entries: # convert sub-request index to global index orig_idx = active_request_indices[result.index] diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 78cb7a991..2fe48e9e9 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -15,13 +15,7 @@ from __future__ import annotations -from typing import ( - TYPE_CHECKING, - AsyncGenerator, - AsyncIterable, - Awaitable, - Sequence, -) +from typing import Sequence from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB @@ -32,6 +26,7 @@ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.exceptions import _RowSetComplete +from google.cloud.bigtable.data.exceptions import _ResetRow from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory @@ -39,15 +34,13 @@ from google.api_core import retry as retries from google.api_core.retry import exponential_sleep_generator -if TYPE_CHECKING: - from google.cloud.bigtable.data._async.client import TableAsync - - -class _ResetRow(Exception): - def __init__(self, chunk): - self.chunk = chunk +from google.cloud.bigtable.data._sync.cross_sync import CrossSync +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", + add_mapping_for_name="_ReadRowsOperation", +) class _ReadRowsOperationAsync: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream @@ -82,7 +75,7 @@ class _ReadRowsOperationAsync: def __init__( self, query: ReadRowsQuery, - table: "TableAsync", + table: "CrossSync.Table", operation_timeout: float, attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), @@ -108,14 +101,14 @@ def __init__( self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None - def start_operation(self) -> AsyncGenerator[Row, None]: + def start_operation(self) -> CrossSync.Iterable[Row]: """ Start the read_rows operation, retrying on retryable errors. Yields: Row: The next row in the stream """ - return retries.retry_target_stream_async( + return CrossSync.retry_target_stream( self._read_rows_attempt, self._predicate, exponential_sleep_generator(0.01, 60, multiplier=2), @@ -123,7 +116,7 @@ def start_operation(self) -> AsyncGenerator[Row, None]: exception_factory=_retry_exception_factory, ) - def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: + def _read_rows_attempt(self) -> CrossSync.Iterable[Row]: """ Attempt a single read_rows rpc call. 
This function is intended to be wrapped by retry logic, @@ -159,9 +152,10 @@ def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: chunked_stream = self.chunk_stream(gapic_stream) return self.merge_rows(chunked_stream) + @CrossSync.convert async def chunk_stream( - self, stream: Awaitable[AsyncIterable[ReadRowsResponsePB]] - ) -> AsyncGenerator[ReadRowsResponsePB.CellChunk, None]: + self, stream: CrossSync.Awaitable[CrossSync.Iterable[ReadRowsResponsePB]] + ) -> CrossSync.Iterable[ReadRowsResponsePB.CellChunk]: """ process chunks out of raw read_rows stream @@ -170,7 +164,7 @@ async def chunk_stream( Yields: ReadRowsResponsePB.CellChunk: the next chunk in the stream """ - async for resp in await stream: + async for resp in CrossSync.rm_aio(await stream): # extract proto from proto-plus wrapper resp = resp._pb @@ -211,9 +205,12 @@ async def chunk_stream( current_key = None @staticmethod + @CrossSync.convert( + replace_symbols={"__aiter__": "__iter__", "__anext__": "__next__"} + ) async def merge_rows( - chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None - ) -> AsyncGenerator[Row, None]: + chunks: CrossSync.Iterable[ReadRowsResponsePB.CellChunk] | None, + ) -> CrossSync.Iterable[Row]: """ Merge chunks into rows @@ -228,8 +225,8 @@ async def merge_rows( # For each row while True: try: - c = await it.__anext__() - except StopAsyncIteration: + c = CrossSync.rm_aio(await it.__anext__()) + except CrossSync.StopIteration: # stream complete return row_key = c.row_key @@ -277,7 +274,7 @@ async def merge_rows( buffer = [value] while c.value_size > 0: # throws when premature end - c = await it.__anext__() + c = CrossSync.rm_aio(await it.__anext__()) t = c.timestamp_micros cl = c.labels @@ -309,7 +306,7 @@ async def merge_rows( if c.commit_row: yield Row(row_key, cells) break - c = await it.__anext__() + c = CrossSync.rm_aio(await it.__anext__()) except _ResetRow as e: c = e.chunk if ( @@ -322,7 +319,7 @@ async def merge_rows( ): raise InvalidChunk("reset row with data") continue - except StopAsyncIteration: + except CrossSync.StopIteration: raise InvalidChunk("premature end of stream") @staticmethod diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 34fdf847a..f18b46256 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -25,22 +25,18 @@ TYPE_CHECKING, ) -import asyncio -import grpc import time import warnings -import sys import random import os +import concurrent.futures from functools import partial +from grpc import Channel from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta -from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient -from google.cloud.bigtable_v2.services.bigtable.async_client import DEFAULT_CLIENT_INFO -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - PooledChannel, +from google.cloud.bigtable_v2.services.bigtable.transports.base import ( + DEFAULT_CLIENT_INFO, ) from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest from google.cloud.client import ClientWithProject @@ -49,7 +45,6 @@ from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import ServiceUnavailable from google.api_core.exceptions import Aborted -from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync import google.auth.credentials import google.auth._default @@ -60,8 +55,6 
@@ from google.cloud.bigtable.data.exceptions import FailedQueryShardError from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup -from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry -from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import _WarmedInstanceKey from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT @@ -71,21 +64,53 @@ from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import _attempt_timeout_generator -from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync -from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE +from google.cloud.bigtable.data._helpers import _MB_SIZE +from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry + from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule from google.cloud.bigtable.data.row_filters import RowFilter from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter from google.cloud.bigtable.data.row_filters import RowFilterChain +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +if CrossSync.is_async: + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel as AsyncPooledChannel, + ) + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) + + # define file-specific cross-sync replacements + CrossSync.add_mapping("GapicClient", BigtableAsyncClient) + CrossSync.add_mapping("PooledTransport", PooledBigtableGrpcAsyncIOTransport) + CrossSync.add_mapping("PooledChannel", AsyncPooledChannel) + CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) + CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) + CrossSync.add_mapping("MutationsBatcher", MutationsBatcherAsync) + if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync.client.BigtableDataClient", + add_mapping_for_name="DataClient", +) class BigtableDataClientAsync(ClientWithProject): + @CrossSync.convert def __init__( self, *, @@ -120,8 +145,8 @@ def __init__( ValueError: if pool_size is less than 1 """ # set up transport in registry - transport_str = f"pooled_grpc_asyncio_{pool_size}" - transport = PooledBigtableGrpcAsyncIOTransport.with_fixed_size(pool_size) + transport_str = f"bt-{self._client_version()}-{pool_size}" + transport = CrossSync.PooledTransport.with_fixed_size(pool_size) BigtableClientMeta._transport_registry[transport_str] = transport # set up client info headers for veneer library client_info = DEFAULT_CLIENT_INFO @@ -146,22 +171,24 @@ def __init__( project=project, 
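CrossSync.add_mapping plus attribute access (CrossSync.GapicClient, CrossSync.PooledTransport, and so on) suggests a simple class-level registry, populated differently by the async module and the generated sync module. One plausible shape, using stand-in names:

    class _CrossSyncMeta(type):
        def __getattr__(cls, name):
            # consulted only when normal attribute lookup fails
            try:
                return cls._mapping[name]
            except KeyError:
                raise AttributeError(name) from None

    class CrossSyncSketch(metaclass=_CrossSyncMeta):
        is_async = True
        _mapping: dict = {}

        @classmethod
        def add_mapping(cls, name, value):
            cls._mapping[name] = value

    class FakeAsyncGapicClient:   # stand-in for BigtableAsyncClient
        pass

    CrossSyncSketch.add_mapping("GapicClient", FakeAsyncGapicClient)
    assert CrossSyncSketch.GapicClient is FakeAsyncGapicClient
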
client_options=client_options, ) - self._gapic_client = BigtableAsyncClient( + self._gapic_client = CrossSync.GapicClient( transport=transport_str, credentials=credentials, client_options=client_options, client_info=client_info, ) - self.transport = cast( - PooledBigtableGrpcAsyncIOTransport, self._gapic_client.transport - ) + self._is_closed = CrossSync.Event() + self.transport = cast(CrossSync.PooledTransport, self._gapic_client.transport) # keep track of active instances to for warmup on channel refresh self._active_instances: Set[_WarmedInstanceKey] = set() # keep track of table objects associated with each instance # only remove instance from _active_instances when all associated tables remove it self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() - self._channel_refresh_tasks: list[asyncio.Task[None]] = [] + self._channel_refresh_tasks: list[CrossSync.Task[None]] = [] + self._executor = ( + concurrent.futures.ThreadPoolExecutor() if not CrossSync.is_async else None + ) if self._emulator_host is not None: # connect to an emulator host warnings.warn( @@ -169,7 +196,7 @@ def __init__( RuntimeWarning, stacklevel=2, ) - self.transport._grpc_channel = PooledChannel( + self.transport._grpc_channel = CrossSync.PooledChannel( pool_size=pool_size, host=self._emulator_host, insecure=True, @@ -194,7 +221,10 @@ def _client_version() -> str: """ Helper function to return the client version string for this client """ - return f"{google.cloud.bigtable.__version__}-data-async" + version_str = f"{google.cloud.bigtable.__version__}-data" + if CrossSync.is_async: + version_str += "-async" + return version_str def _start_background_channel_refresh(self) -> None: """ @@ -203,31 +233,41 @@ def _start_background_channel_refresh(self) -> None: Raises: RuntimeError: if not called in an asyncio event loop """ - if not self._channel_refresh_tasks and not self._emulator_host: - # raise RuntimeError if there is no event loop - asyncio.get_running_loop() + if ( + not self._channel_refresh_tasks + and not self._emulator_host + and not self._is_closed.is_set() + ): + # raise error if not in an event loop in async client + CrossSync.verify_async_event_loop() for channel_idx in range(self.transport.pool_size): - refresh_task = asyncio.create_task(self._manage_channel(channel_idx)) - if sys.version_info >= (3, 8): - # task names supported in Python 3.8+ - refresh_task.set_name( - f"{self.__class__.__name__} channel refresh {channel_idx}" - ) + refresh_task = CrossSync.create_task( + self._manage_channel, + channel_idx, + sync_executor=self._executor, + task_name=f"{self.__class__.__name__} channel refresh {channel_idx}", + ) self._channel_refresh_tasks.append(refresh_task) - async def close(self, timeout: float = 2.0): + @CrossSync.convert + async def close(self, timeout: float | None = 2.0): """ Cancel all background tasks """ + self._is_closed.set() for task in self._channel_refresh_tasks: task.cancel() - group = asyncio.gather(*self._channel_refresh_tasks, return_exceptions=True) - await asyncio.wait_for(group, timeout=timeout) - await self.transport.close() + CrossSync.rm_aio(await self.transport.close()) + if self._executor: + self._executor.shutdown(wait=False) + CrossSync.rm_aio( + await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) + ) self._channel_refresh_tasks = [] + @CrossSync.convert async def _ping_and_warm_instances( - self, channel: grpc.aio.Channel, instance_key: _WarmedInstanceKey | None = None + self, channel: Channel, instance_key: 
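CrossSync.create_task(fn, *args, sync_executor=..., task_name=...) replaces the direct asyncio.create_task call so the generated sync client can run the same call site on a thread pool. A sketch of the dispatch this implies (an assumption, not the library's implementation):

    import asyncio

    def create_task_sketch(fn, *args, sync_executor=None, task_name=None,
                           is_async=True):
        if is_async:
            task = asyncio.create_task(fn(*args))
            if task_name:
                task.set_name(task_name)
            return task
        if sync_executor is None:
            raise ValueError("sync_executor is required on the sync surface")
        # returns a concurrent.futures.Future, the sync analogue of a Task
        return sync_executor.submit(fn, *args)
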
_WarmedInstanceKey | None = None ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel @@ -248,8 +288,9 @@ async def _ping_and_warm_instances( request_serializer=PingAndWarmRequest.serialize, ) # prepare list of coroutines to run - tasks = [ - ping_rpc( + partial_list = [ + partial( + ping_rpc, request={"name": instance_name, "app_profile_id": app_profile_id}, metadata=[ ( @@ -261,11 +302,14 @@ async def _ping_and_warm_instances( ) for (instance_name, table_name, app_profile_id) in instance_list ] - # execute coroutines in parallel - result_list = await asyncio.gather(*tasks, return_exceptions=True) - # return None in place of empty successful responses + result_list = CrossSync.rm_aio( + await CrossSync.gather_partials( + partial_list, return_exceptions=True, sync_executor=self._executor + ) + ) return [r or None for r in result_list] + @CrossSync.convert async def _manage_channel( self, channel_idx: int, @@ -299,22 +343,37 @@ async def _manage_channel( if next_sleep > 0: # warm the current channel immediately channel = self.transport.channels[channel_idx] - await self._ping_and_warm_instances(channel) + CrossSync.rm_aio(await self._ping_and_warm_instances(channel)) # continuously refresh the channel every `refresh_interval` seconds - while True: - await asyncio.sleep(next_sleep) + while not self._is_closed.is_set(): + CrossSync.rm_aio( + await CrossSync.event_wait( + self._is_closed, + next_sleep, + async_break_early=False, # no need to interrupt sleep. Task will be cancelled on close + ) + ) + if self._is_closed.is_set(): + # don't refresh if client is closed + break # prepare new channel for use new_channel = self.transport.grpc_channel._create_channel() - await self._ping_and_warm_instances(new_channel) + CrossSync.rm_aio(await self._ping_and_warm_instances(new_channel)) # cycle channel out of use, with long grace window before closure - start_timestamp = time.time() - await self.transport.replace_channel( - channel_idx, grace=grace_period, swap_sleep=10, new_channel=new_channel + start_timestamp = time.monotonic() + CrossSync.rm_aio( + await self.transport.replace_channel( + channel_idx, + grace=grace_period, + new_channel=new_channel, + event=self._is_closed, + ) ) # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) - next_sleep = next_refresh - (time.time() - start_timestamp) + next_sleep = next_refresh - (time.monotonic() - start_timestamp) + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: """ Registers an instance with the client, and warms the channel pool @@ -340,11 +399,14 @@ async def _register_instance(self, instance_id: str, owner: TableAsync) -> None: # refresh tasks already running # call ping and warm on all existing channels for channel in self.transport.channels: - await self._ping_and_warm_instances(channel, instance_key) + CrossSync.rm_aio( + await self._ping_and_warm_instances(channel, instance_key) + ) else: # refresh tasks aren't active. 
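_ping_and_warm_instances (and later read_rows_sharded) now builds a list of functools.partial objects instead of coroutines, so the fan-out helper can decide when and where each call is invoked. A sketch of gather_partials under the assumption that each partial returns a coroutine in async mode and a plain value when run on the executor:

    import asyncio

    async def gather_partials_sketch(partials, *, return_exceptions=False,
                                     sync_executor=None, is_async=True):
        if is_async:
            # each partial() produces a coroutine; run them concurrently
            return await asyncio.gather(
                *(p() for p in partials), return_exceptions=return_exceptions
            )
        # sync path (the generated code would drop async/await entirely)
        futures = [sync_executor.submit(p) for p in partials]
        results = []
        for fut in futures:
            try:
                results.append(fut.result())
            except Exception as exc:
                if not return_exceptions:
                    raise
                results.append(exc)
        return results
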
start them as background tasks self._start_background_channel_refresh() + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) async def _remove_instance_registration( self, instance_id: str, owner: TableAsync ) -> bool: @@ -375,6 +437,7 @@ async def _remove_instance_registration( except KeyError: return False + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: """ Returns a table instance for making data API requests. All arguments are passed @@ -416,15 +479,20 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs """ return TableAsync(self, instance_id, table_id, *args, **kwargs) + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): self._start_background_channel_refresh() return self + @CrossSync.convert(sync_name="__exit__", replace_symbols={"__aexit__": "__exit__"}) async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) + CrossSync.rm_aio(await self.close()) + CrossSync.rm_aio(await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb)) +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync.client.Table", add_mapping_for_name="Table" +) class TableAsync: """ Main Data API surface @@ -433,6 +501,9 @@ class TableAsync: each call """ + @CrossSync.convert( + replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} + ) def __init__( self, client: BigtableDataClientAsync, @@ -541,17 +612,19 @@ def __init__( default_mutate_rows_retryable_errors or () ) self.default_retryable_errors = default_retryable_errors or () - - # raises RuntimeError if called outside of an async context (no running event loop) try: - self._register_instance_task = asyncio.create_task( - self.client._register_instance(instance_id, self) + self._register_instance_future = CrossSync.create_task( + self.client._register_instance, + self.instance_id, + self, + sync_executor=self.client._executor, ) except RuntimeError as e: raise RuntimeError( f"{self.__class__.__name__} must be created within an async event loop context." 
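Taken together, the export_sync path and the sync_name hints on __aenter__/__aexit__ point at the intended end state: identical context-manager ergonomics on both surfaces. A usage sketch; the sync import path is taken from the decorator arguments above and assumes the generated module exists:

    async def use_async_client():
        from google.cloud.bigtable.data import BigtableDataClientAsync

        async with BigtableDataClientAsync(project="my-project") as client:
            async with client.get_table("my-instance", "my-table") as table:
                return await table.read_rows({})

    def use_generated_sync_client():
        # assumed output location, per export_sync(path=...) above
        from google.cloud.bigtable.data._sync.client import BigtableDataClient

        with BigtableDataClient(project="my-project") as client:
            with client.get_table("my-instance", "my-table") as table:
                return table.read_rows({})
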
) from e + @CrossSync.convert(replace_symbols={"AsyncIterable": "Iterable"}) async def read_rows_stream( self, query: ReadRowsQuery, @@ -593,7 +666,7 @@ async def read_rows_stream( ) retryable_excs = _get_retryable_errors(retryable_errors, self) - row_merger = _ReadRowsOperationAsync( + row_merger = CrossSync._ReadRowsOperation( query, self, operation_timeout=operation_timeout, @@ -602,6 +675,7 @@ async def read_rows_stream( ) return row_merger.start_operation() + @CrossSync.convert async def read_rows( self, query: ReadRowsQuery, @@ -641,14 +715,17 @@ async def read_rows( from any retries that failed google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error """ - row_generator = await self.read_rows_stream( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, + row_generator = CrossSync.rm_aio( + await self.read_rows_stream( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) ) - return [row async for row in row_generator] + return CrossSync.rm_aio([row async for row in row_generator]) + @CrossSync.convert async def read_row( self, row_key: str | bytes, @@ -688,16 +765,19 @@ async def read_row( if row_key is None: raise ValueError("row_key must be string or bytes") query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1) - results = await self.read_rows( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, + results = CrossSync.rm_aio( + await self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) ) if len(results) == 0: return None return results[0] + @CrossSync.convert async def read_rows_sharded( self, sharded_query: ShardedQuery, @@ -748,25 +828,35 @@ async def read_rows_sharded( ) # limit the number of concurrent requests using a semaphore - concurrency_sem = asyncio.Semaphore(_CONCURRENCY_LIMIT) + concurrency_sem = CrossSync.Semaphore(_CONCURRENCY_LIMIT) async def read_rows_with_semaphore(query): - async with concurrency_sem: + async with CrossSync.rm_aio(concurrency_sem): # calculate new timeout based on time left in overall operation shard_timeout = next(rpc_timeout_generator) if shard_timeout <= 0: raise DeadlineExceeded( "Operation timeout exceeded before starting query" ) - return await self.read_rows( - query, - operation_timeout=shard_timeout, - attempt_timeout=min(attempt_timeout, shard_timeout), - retryable_errors=retryable_errors, + return CrossSync.rm_aio( + await self.read_rows( + query, + operation_timeout=shard_timeout, + attempt_timeout=min(attempt_timeout, shard_timeout), + retryable_errors=retryable_errors, + ) ) - routine_list = [read_rows_with_semaphore(query) for query in sharded_query] - batch_result = await asyncio.gather(*routine_list, return_exceptions=True) + routine_list = [ + partial(read_rows_with_semaphore, query) for query in sharded_query + ] + batch_result = CrossSync.rm_aio( + await CrossSync.gather_partials( + routine_list, + return_exceptions=True, + sync_executor=self.client._executor, + ) + ) # collect results and errors error_dict = {} @@ -793,6 +883,7 @@ async def read_rows_with_semaphore(query): ) return results_list + @CrossSync.convert async def row_exists( self, row_key: str | bytes, @@ -833,14 +924,17 @@ async def row_exists( limit_filter = CellsRowLimitFilter(1) chain_filter = 
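The rm_aio wrapper around the whole list comprehension in read_rows shows the marker applies to arbitrary expressions, not just awaited calls. A small before/after pair illustrating the expected rewrite (the sync form is an assumption about the generator's output):

    async def collect_rows_async(row_generator):
        # as written on the async surface above
        return [row async for row in row_generator]

    def collect_rows_sync(row_generator):
        # expected generated form once rm_aio and the async keywords are stripped
        return [row for row in row_generator]
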
RowFilterChain(filters=[limit_filter, strip_filter]) query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter) - results = await self.read_rows( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, + results = CrossSync.rm_aio( + await self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) ) return len(results) > 0 + @CrossSync.convert async def sample_row_keys( self, *, @@ -896,23 +990,30 @@ async def sample_row_keys( metadata = _make_metadata(self.table_name, self.app_profile_id) async def execute_rpc(): - results = await self.client._gapic_client.sample_row_keys( - table_name=self.table_name, - app_profile_id=self.app_profile_id, - timeout=next(attempt_timeout_gen), - metadata=metadata, - retry=None, + results = CrossSync.rm_aio( + await self.client._gapic_client.sample_row_keys( + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=next(attempt_timeout_gen), + metadata=metadata, + retry=None, + ) + ) + return CrossSync.rm_aio( + [(s.row_key, s.offset_bytes) async for s in results] ) - return [(s.row_key, s.offset_bytes) async for s in results] - return await retries.retry_target_async( - execute_rpc, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, + return CrossSync.rm_aio( + await CrossSync.retry_target( + execute_rpc, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) ) + @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) def mutations_batcher( self, *, @@ -950,7 +1051,7 @@ def mutations_batcher( Returns: MutationsBatcherAsync: a MutationsBatcherAsync context manager that can batch requests """ - return MutationsBatcherAsync( + return CrossSync.MutationsBatcher( self, flush_interval=flush_interval, flush_limit_mutation_count=flush_limit_mutation_count, @@ -962,6 +1063,7 @@ def mutations_batcher( batch_retryable_errors=batch_retryable_errors, ) + @CrossSync.convert async def mutate_row( self, row_key: str | bytes, @@ -1032,14 +1134,17 @@ async def mutate_row( metadata=_make_metadata(self.table_name, self.app_profile_id), retry=None, ) - return await retries.retry_target_async( - target, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, + return CrossSync.rm_aio( + await CrossSync.retry_target( + target, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) ) + @CrossSync.convert async def bulk_mutate_rows( self, mutation_entries: list[RowMutationEntry], @@ -1085,7 +1190,7 @@ async def bulk_mutate_rows( ) retryable_excs = _get_retryable_errors(retryable_errors, self) - operation = _MutateRowsOperationAsync( + operation = CrossSync._MutateRowsOperation( self.client._gapic_client, self, mutation_entries, @@ -1093,8 +1198,9 @@ async def bulk_mutate_rows( attempt_timeout, retryable_exceptions=retryable_excs, ) - await operation.start() + CrossSync.rm_aio(await operation.start()) + @CrossSync.convert async def check_and_mutate_row( self, row_key: str | bytes, @@ -1148,19 +1254,24 @@ async def check_and_mutate_row( false_case_mutations = [false_case_mutations] false_case_list = [m._to_pb() for m in false_case_mutations or []] metadata = _make_metadata(self.table_name, self.app_profile_id) - result = await self.client._gapic_client.check_and_mutate_row( - 
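sample_row_keys and mutate_row now go through CrossSync.retry_target where the async code previously called retries.retry_target_async directly. A plausible dispatch, reusing the api_core helpers this module already depends on (sketch only):

    from google.api_core import retry as retries

    def retry_target_sketch(target, predicate, sleep_generator, timeout,
                            exception_factory=None, is_async=True):
        if is_async:
            return retries.retry_target_async(
                target, predicate, sleep_generator, timeout,
                exception_factory=exception_factory,
            )
        return retries.retry_target(
            target, predicate, sleep_generator, timeout,
            exception_factory=exception_factory,
        )
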
true_mutations=true_case_list, - false_mutations=false_case_list, - predicate_filter=predicate._to_pb() if predicate is not None else None, - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, + result = CrossSync.rm_aio( + await self.client._gapic_client.check_and_mutate_row( + true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) ) return result.predicate_matched + @CrossSync.convert async def read_modify_write_row( self, row_key: str | bytes, @@ -1199,25 +1310,34 @@ async def read_modify_write_row( if not rules: raise ValueError("rules must contain at least one item") metadata = _make_metadata(self.table_name, self.app_profile_id) - result = await self.client._gapic_client.read_modify_write_row( - rules=[rule._to_pb() for rule in rules], - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, + result = CrossSync.rm_aio( + await self.client._gapic_client.read_modify_write_row( + rules=[rule._to_pb() for rule in rules], + row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) ) # construct Row from result return Row._from_pb(result.row) + @CrossSync.convert async def close(self): """ Called to close the Table instance and release any resources held by it. 
""" - self._register_instance_task.cancel() - await self.client._remove_instance_registration(self.instance_id, self) + if self._register_instance_future: + self._register_instance_future.cancel() + CrossSync.rm_aio( + await self.client._remove_instance_registration(self.instance_id, self) + ) + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """ Implement async context manager protocol @@ -1225,9 +1345,11 @@ async def __aenter__(self): Ensure registration task has time to run, so that grpc channels will be warmed for the specified instance """ - await self._register_instance_task + if self._register_instance_future: + CrossSync.rm_aio(await self._register_instance_future) return self + @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc_val, exc_tb): """ Implement async context manager protocol @@ -1235,4 +1357,4 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): Unregister this instance with the client, so that grpc channels will no longer be warmed """ - await self.close() + CrossSync.rm_aio(await self.close()) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 76d13f00b..7a6def9e4 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -14,32 +14,37 @@ # from __future__ import annotations -from typing import Any, Sequence, TYPE_CHECKING -import asyncio +from typing import Sequence, TYPE_CHECKING import atexit import warnings from collections import deque +import concurrent.futures -from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.exceptions import FailedMutationEntryError from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _MB_SIZE -from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync -from google.cloud.bigtable.data._async._mutate_rows import ( +from google.cloud.bigtable.data.mutations import ( _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, ) from google.cloud.bigtable.data.mutations import Mutation +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + if TYPE_CHECKING: - from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data.mutations import RowMutationEntry -# used to make more readable default values -_MB_SIZE = 1024 * 1024 + if CrossSync.is_async: + from google.cloud.bigtable.data._async.client import TableAsync +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync.mutations_batcher._FlowControl", + add_mapping_for_name="_FlowControl", +) class _FlowControlAsync: """ Manages flow control for batched mutations. 
Mutations are registered against @@ -70,7 +75,7 @@ def __init__( raise ValueError("max_mutation_count must be greater than 0") if self._max_mutation_bytes < 1: raise ValueError("max_mutation_bytes must be greater than 0") - self._capacity_condition = asyncio.Condition() + self._capacity_condition = CrossSync.Condition() self._in_flight_mutation_count = 0 self._in_flight_mutation_bytes = 0 @@ -96,6 +101,7 @@ def _has_capacity(self, additional_count: int, additional_size: int) -> bool: new_count = self._in_flight_mutation_count + additional_count return new_size <= acceptable_size and new_count <= acceptable_count + @CrossSync.convert async def remove_from_flow( self, mutations: RowMutationEntry | list[RowMutationEntry] ) -> None: @@ -114,9 +120,10 @@ async def remove_from_flow( self._in_flight_mutation_count -= total_count self._in_flight_mutation_bytes -= total_size # notify any blocked requests that there is additional capacity - async with self._capacity_condition: + async with CrossSync.rm_aio(self._capacity_condition): self._capacity_condition.notify_all() + @CrossSync.convert async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): """ Generator function that registers mutations with flow control. As mutations @@ -139,7 +146,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] start_idx = end_idx batch_mutation_count = 0 # fill up batch until we hit capacity - async with self._capacity_condition: + async with CrossSync.rm_aio(self._capacity_condition): while end_idx < len(mutations): next_entry = mutations[end_idx] next_size = next_entry.size() @@ -160,12 +167,19 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] break else: # batch is empty. Block until we have capacity - await self._capacity_condition.wait_for( - lambda: self._has_capacity(next_count, next_size) + CrossSync.rm_aio( + await self._capacity_condition.wait_for( + lambda: self._has_capacity(next_count, next_size) + ) ) yield mutations[start_idx:end_idx] +@CrossSync.export_sync( + path="google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", + mypy_ignore=["unreachable"], + add_mapping_for_name="MutationsBatcher", +) class MutationsBatcherAsync: """ Allows users to send batches using context manager API: @@ -197,9 +211,10 @@ class MutationsBatcherAsync: Defaults to the Table's default_mutate_rows_retryable_errors. 
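_FlowControlAsync now takes its condition variable from CrossSync.Condition, and the batcher below does the same with CrossSync.Event and CrossSync.Semaphore. The natural mapping between the two surfaces (an assumption, but consistent with how add_to_flow uses wait_for, which both Condition types provide):

    import asyncio
    import threading

    ASYNC_PRIMITIVES = {
        "Condition": asyncio.Condition,
        "Event": asyncio.Event,
        "Semaphore": asyncio.Semaphore,
    }
    SYNC_PRIMITIVES = {
        "Condition": threading.Condition,
        "Event": threading.Event,
        "Semaphore": threading.Semaphore,
    }
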
""" + @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def __init__( self, - table: "TableAsync", + table: TableAsync, *, flush_interval: float | None = 5, flush_limit_mutation_count: int | None = 1000, @@ -218,11 +233,11 @@ def __init__( batch_retryable_errors, table ) - self.closed: bool = False + self._closed = CrossSync.Event() self._table = table self._staged_entries: list[RowMutationEntry] = [] self._staged_count, self._staged_bytes = 0, 0 - self._flow_control = _FlowControlAsync( + self._flow_control = CrossSync._FlowControl( flow_control_max_mutation_count, flow_control_max_bytes ) self._flush_limit_bytes = flush_limit_bytes @@ -231,8 +246,15 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - self._flush_timer = self._start_flush_timer(flush_interval) - self._flush_jobs: set[asyncio.Future[None]] = set() + self._sync_executor = ( + concurrent.futures.ThreadPoolExecutor(max_workers=8) + if not CrossSync.is_async + else None + ) + self._flush_timer = CrossSync.create_task( + self._timer_routine, flush_interval, sync_executor=self._sync_executor + ) + self._flush_jobs: set[CrossSync.Future[None]] = set() # MutationExceptionGroup reports number of successful entries along with failures self._entries_processed_since_last_raise: int = 0 self._exceptions_since_last_raise: int = 0 @@ -245,7 +267,8 @@ def __init__( # clean up on program exit atexit.register(self._on_exit) - def _start_flush_timer(self, interval: float | None) -> asyncio.Future[None]: + @CrossSync.convert + async def _timer_routine(self, interval: float | None) -> None: """ Set up a background task to flush the batcher every interval seconds @@ -254,27 +277,20 @@ def _start_flush_timer(self, interval: float | None) -> asyncio.Future[None]: Args: flush_interval: Automatically flush every flush_interval seconds. If None, no time-based flushing is performed. 
- Returns: - asyncio.Future[None]: future representing the background task """ - if interval is None or self.closed: - empty_future: asyncio.Future[None] = asyncio.Future() - empty_future.set_result(None) - return empty_future - - async def timer_routine(self, interval: float): - """ - Triggers new flush tasks every `interval` seconds - """ - while not self.closed: - await asyncio.sleep(interval) - # add new flush task to list - if not self.closed and self._staged_entries: - self._schedule_flush() - - timer_task = asyncio.create_task(timer_routine(self, interval)) - return timer_task + if not interval or interval <= 0: + return None + while not self._closed.is_set(): + # wait until interval has passed, or until closed + CrossSync.rm_aio( + await CrossSync.event_wait( + self._closed, timeout=interval, async_break_early=False + ) + ) + if not self._closed.is_set() and self._staged_entries: + self._schedule_flush() + @CrossSync.convert async def append(self, mutation_entry: RowMutationEntry): """ Add a new set of mutations to the internal queue @@ -286,7 +302,7 @@ async def append(self, mutation_entry: RowMutationEntry): ValueError: if an invalid mutation type is added """ # TODO: return a future to track completion of this entry - if self.closed: + if self._closed.is_set(): raise RuntimeError("Cannot append to closed MutationsBatcher") if isinstance(mutation_entry, Mutation): # type: ignore raise ValueError( @@ -302,25 +318,29 @@ async def append(self, mutation_entry: RowMutationEntry): ): self._schedule_flush() # yield to the event loop to allow flush to run - await asyncio.sleep(0) + CrossSync.rm_aio(await CrossSync.yield_to_event_loop()) - def _schedule_flush(self) -> asyncio.Future[None] | None: + def _schedule_flush(self) -> CrossSync.Future[None] | None: """ Update the flush task to include the latest staged entries Returns: - asyncio.Future[None] | None: + Future[None] | None: future representing the background task, if started """ if self._staged_entries: entries, self._staged_entries = self._staged_entries, [] self._staged_count, self._staged_bytes = 0, 0 - new_task = self._create_bg_task(self._flush_internal, entries) - new_task.add_done_callback(self._flush_jobs.remove) - self._flush_jobs.add(new_task) + new_task = CrossSync.create_task( + self._flush_internal, entries, sync_executor=self._sync_executor + ) + if not new_task.done(): + self._flush_jobs.add(new_task) + new_task.add_done_callback(self._flush_jobs.remove) return new_task return None + @CrossSync.convert async def _flush_internal(self, new_entries: list[RowMutationEntry]): """ Flushes a set of mutations to the server, and updates internal state @@ -329,16 +349,23 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): new_entries list of RowMutationEntry objects to flush """ # flush new entries - in_process_requests: list[asyncio.Future[list[FailedMutationEntryError]]] = [] - async for batch in self._flow_control.add_to_flow(new_entries): - batch_task = self._create_bg_task(self._execute_mutate_rows, batch) + in_process_requests: list[CrossSync.Future[list[FailedMutationEntryError]]] = [] + async for batch in CrossSync.rm_aio( + self._flow_control.add_to_flow(new_entries) + ): + batch_task = CrossSync.create_task( + self._execute_mutate_rows, batch, sync_executor=self._sync_executor + ) in_process_requests.append(batch_task) # wait for all inflight requests to complete - found_exceptions = await self._wait_for_batch_results(*in_process_requests) + found_exceptions = CrossSync.rm_aio( + await 
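Two small helpers carry most of this hunk: event_wait, a sleep that can be cut short by setting the event (here async_break_early=False opts out, since close() cancels the task anyway), and yield_to_event_loop, the replacement for the old await asyncio.sleep(0) in append. Sketches of the async side of both (assumptions):

    import asyncio

    async def event_wait_sketch(event, timeout=None, async_break_early=True):
        # wait for the event or the timeout, whichever comes first, never raising
        if async_break_early:
            try:
                await asyncio.wait_for(event.wait(), timeout)
            except asyncio.TimeoutError:
                pass
        else:
            # caller does not need early wake-up; a plain sleep is enough
            await asyncio.sleep(timeout if timeout is not None else 0)

    async def yield_to_event_loop_sketch():
        # give pending tasks (such as a freshly scheduled flush) a chance to run
        await asyncio.sleep(0)
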
self._wait_for_batch_results(*in_process_requests) + ) # update exception data to reflect any new errors self._entries_processed_since_last_raise += len(new_entries) self._add_exceptions(found_exceptions) + @CrossSync.convert async def _execute_mutate_rows( self, batch: list[RowMutationEntry] ) -> list[FailedMutationEntryError]: @@ -355,7 +382,7 @@ async def _execute_mutate_rows( FailedMutationEntryError objects will not contain index information """ try: - operation = _MutateRowsOperationAsync( + operation = CrossSync._MutateRowsOperation( self._table.client._gapic_client, self._table, batch, @@ -363,7 +390,7 @@ async def _execute_mutate_rows( attempt_timeout=self._attempt_timeout, retryable_exceptions=self._retryable_errors, ) - await operation.start() + CrossSync.rm_aio(await operation.start()) except MutationsExceptionGroup as e: # strip index information from exceptions, since it is not useful in a batch context for subexc in e.exceptions: @@ -371,7 +398,7 @@ async def _execute_mutate_rows( return list(e.exceptions) finally: # mark batch as complete in flow control - await self._flow_control.remove_from_flow(batch) + CrossSync.rm_aio(await self._flow_control.remove_from_flow(batch)) return [] def _add_exceptions(self, excs: list[Exception]): @@ -419,31 +446,41 @@ def _raise_exceptions(self): entry_count=entry_count, ) + @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): """Allow use of context manager API""" return self + @CrossSync.convert(sync_name="__exit__") async def __aexit__(self, exc_type, exc, tb): """ Allow use of context manager API. Flushes the batcher and cleans up resources. """ - await self.close() + CrossSync.rm_aio(await self.close()) + + @property + def closed(self) -> bool: + """ + Returns: + - True if the batcher is closed, False otherwise + """ + return self._closed.is_set() + @CrossSync.convert async def close(self): """ Flush queue and clean up resources """ - self.closed = True + self._closed.set() self._flush_timer.cancel() self._schedule_flush() - if self._flush_jobs: - await asyncio.gather(*self._flush_jobs, return_exceptions=True) - try: - await self._flush_timer - except asyncio.CancelledError: - pass + CrossSync.rm_aio(await CrossSync.wait([*self._flush_jobs, self._flush_timer])) + # shut down executor + if self._sync_executor: + with self._sync_executor: + self._sync_executor.shutdown(wait=True) atexit.unregister(self._on_exit) # raise unreported exceptions self._raise_exceptions() @@ -452,32 +489,17 @@ def _on_exit(self): """ Called when program is exited. Raises warning if unflushed mutations remain """ - if not self.closed and self._staged_entries: + if not self._closed.is_set() and self._staged_entries: warnings.warn( f"MutationsBatcher for table {self._table.table_name} was not closed. " f"{len(self._staged_entries)} Unflushed mutations will not be sent to the server." ) @staticmethod - def _create_bg_task(func, *args, **kwargs) -> asyncio.Future[Any]: - """ - Create a new background task, and return a future - - This method wraps asyncio to make it easier to maintain subclasses - with different concurrency models. 
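close() now funnels the flush jobs and the timer through CrossSync.wait instead of gather plus wait_for. A sketch of such a helper (assumption); note that neither asyncio.wait nor concurrent.futures.wait re-raises task exceptions, which is why the rewritten _wait_for_batch_results below inspects each task's result explicitly:

    import asyncio
    import concurrent.futures

    async def wait_sketch(futures, timeout=None, is_async=True):
        if not futures:
            return
        if is_async:
            await asyncio.wait(futures, timeout=timeout)
        else:
            # the generated sync form would drop the async/await keywords
            concurrent.futures.wait(futures, timeout=timeout)
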
- - Args: - func: function to execute in background task - *args: positional arguments to pass to func - **kwargs: keyword arguments to pass to func - Returns: - asyncio.Future: Future object representing the background task - """ - return asyncio.create_task(func(*args, **kwargs)) - - @staticmethod + @CrossSync.convert async def _wait_for_batch_results( - *tasks: asyncio.Future[list[FailedMutationEntryError]] | asyncio.Future[None], + *tasks: CrossSync.Future[list[FailedMutationEntryError]] + | CrossSync.Future[None], ) -> list[Exception]: """ Takes in a list of futures representing _execute_mutate_rows tasks, @@ -494,19 +516,19 @@ async def _wait_for_batch_results( """ if not tasks: return [] - all_results = await asyncio.gather(*tasks, return_exceptions=True) - found_errors = [] - for result in all_results: - if isinstance(result, Exception): - # will receive direct Exception objects if request task fails - found_errors.append(result) - elif isinstance(result, BaseException): - # BaseException not expected from grpc calls. Raise immediately - raise result - elif result: - # completed requests will return a list of FailedMutationEntryError - for e in result: - # strip index information - e.index = None - found_errors.extend(result) - return found_errors + exceptions: list[Exception] = [] + for task in tasks: + if CrossSync.is_async: + # futures don't need to be awaited in sync mode + CrossSync.rm_aio(await task) + try: + exc_list = task.result() + if exc_list: + # expect a list of FailedMutationEntryError objects + for exc in exc_list: + # strip index information + exc.index = None + exceptions.extend(exc_list) + except Exception as e: + exceptions.append(e) + return exceptions diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index a8fba9ef1..a8113cc4a 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -48,6 +48,9 @@ "_WarmedInstanceKey", ["instance_name", "table_name", "app_profile_id"] ) +# used to make more readable default values +_MB_SIZE = 1024 * 1024 + # enum used on method calls when table defaults should be used class TABLE_DEFAULT(enum.Enum): diff --git a/google/cloud/bigtable/data/exceptions.py b/google/cloud/bigtable/data/exceptions.py index 8d97640aa..8065ed9d1 100644 --- a/google/cloud/bigtable/data/exceptions.py +++ b/google/cloud/bigtable/data/exceptions.py @@ -41,6 +41,21 @@ class _RowSetComplete(Exception): pass +class _ResetRow(Exception): # noqa: F811 + """ + Internal exception for _ReadRowsOperation + + Denotes that the server sent a reset_row marker, telling the client to drop + all previous chunks for row_key and re-read from the beginning. + + Args: + chunk: the reset_row chunk + """ + + def __init__(self, chunk): + self.chunk = chunk + + class _MutateRowsIncomplete(RuntimeError): """ Exception raised when a mutate_rows call has unfinished work. diff --git a/google/cloud/bigtable/data/mutations.py b/google/cloud/bigtable/data/mutations.py index 335a15e12..2f4e441ed 100644 --- a/google/cloud/bigtable/data/mutations.py +++ b/google/cloud/bigtable/data/mutations.py @@ -366,3 +366,15 @@ def _from_dict(cls, input_dict: dict[str, Any]) -> RowMutationEntry: Mutation._from_dict(mutation) for mutation in input_dict["mutations"] ], ) + + +@dataclass +class _EntryWithProto: + """ + A dataclass to hold a RowMutationEntry and its corresponding proto representation. + + Used in _MutateRowsOperation to avoid repeated conversion of RowMutationEntry to proto. 
+ """ + + entry: RowMutationEntry + proto: types_pb.MutateRowsRequest.Entry diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py index 372e5796d..864b4ecc2 100644 --- a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py +++ b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc_asyncio.py @@ -150,7 +150,7 @@ async def wait_for_state_change(self, last_observed_state): raise NotImplementedError() async def replace_channel( - self, channel_idx, grace=None, swap_sleep=1, new_channel=None + self, channel_idx, grace=1, new_channel=None, event=None ) -> aio.Channel: """ Replaces a channel in the pool with a fresh one. @@ -160,13 +160,14 @@ async def replace_channel( Args: channel_idx(int): the channel index in the pool to replace - grace(Optional[float]): The time to wait until all active RPCs are - finished. If a grace period is not specified (by passing None for + grace(Optional[float]): The time to wait for active RPCs to + finish. If a grace period is not specified (by passing None for grace), all existing RPCs are cancelled immediately. - swap_sleep(Optional[float]): The number of seconds to sleep in between - replacing channels and closing the old one + If event is set at close time, grace is ignored new_channel(grpc.aio.Channel): a new channel to insert into the pool at `channel_idx`. If `None`, a new channel will be created. + event(Optional[threading.Event]): an event to signal when the + replacement should be aborted. If set, grace is ignored. """ if channel_idx >= len(self._pool) or channel_idx < 0: raise ValueError( @@ -176,7 +177,8 @@ async def replace_channel( new_channel = self._create_channel() old_channel = self._pool[channel_idx] self._pool[channel_idx] = new_channel - await asyncio.sleep(swap_sleep) + if event is not None and not event.is_set(): + grace = None await old_channel.close(grace=grace) return new_channel @@ -400,7 +402,7 @@ def channels(self) -> List[grpc.Channel]: return self._grpc_channel._pool async def replace_channel( - self, channel_idx, grace=None, swap_sleep=1, new_channel=None + self, channel_idx, grace=1, new_channel=None, event=None ) -> aio.Channel: """ Replaces a channel in the pool with a fresh one. @@ -410,16 +412,17 @@ async def replace_channel( Args: channel_idx(int): the channel index in the pool to replace - grace(Optional[float]): The time to wait until all active RPCs are - finished. If a grace period is not specified (by passing None for + grace(Optional[float]): The time to wait for active RPCs to + finish. If a grace period is not specified (by passing None for grace), all existing RPCs are cancelled immediately. - swap_sleep(Optional[float]): The number of seconds to sleep in between - replacing channels and closing the old one + If event is set at close time, grace is ignored new_channel(grpc.aio.Channel): a new channel to insert into the pool at `channel_idx`. If `None`, a new channel will be created. + event(Optional[threading.Event]): an event to signal when the + replacement should be aborted. If set, grace is ignored. 
""" return await self._grpc_channel.replace_channel( - channel_idx, grace, swap_sleep, new_channel + channel_idx=channel_idx, grace=grace, new_channel=new_channel, event=event ) diff --git a/tests/system/data/__init__.py b/tests/system/data/__init__.py index 89a37dc92..f2952b2cd 100644 --- a/tests/system/data/__init__.py +++ b/tests/system/data/__init__.py @@ -13,3 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +TEST_FAMILY = "test-family" +TEST_FAMILY_2 = "test-family-2" diff --git a/tests/system/data/setup_fixtures.py b/tests/system/data/setup_fixtures.py index 77086b7f3..3b5a0af06 100644 --- a/tests/system/data/setup_fixtures.py +++ b/tests/system/data/setup_fixtures.py @@ -17,20 +17,10 @@ """ import pytest -import pytest_asyncio import os -import asyncio import uuid -@pytest.fixture(scope="session") -def event_loop(): - loop = asyncio.get_event_loop() - yield loop - loop.stop() - loop.close() - - @pytest.fixture(scope="session") def admin_client(): """ @@ -150,22 +140,7 @@ def table_id( print(f"Table {init_table_id} not found, skipping deletion") -@pytest_asyncio.fixture(scope="session") -async def client(): - from google.cloud.bigtable.data import BigtableDataClientAsync - - project = os.getenv("GOOGLE_CLOUD_PROJECT") or None - async with BigtableDataClientAsync(project=project, pool_size=4) as client: - yield client - - @pytest.fixture(scope="session") def project_id(client): """Returns the project ID from the client.""" yield client.project - - -@pytest_asyncio.fixture(scope="session") -async def table(client, table_id, instance_id): - async with client.get_table(instance_id, table_id) as table: - yield table diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py deleted file mode 100644 index 9fe208551..000000000 --- a/tests/system/data/test_system.py +++ /dev/null @@ -1,942 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pytest -import pytest_asyncio -import asyncio -import uuid -import os -from google.api_core import retry -from google.api_core.exceptions import ClientError - -from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE -from google.cloud.environment_vars import BIGTABLE_EMULATOR - -TEST_FAMILY = "test-family" -TEST_FAMILY_2 = "test-family-2" - - -@pytest.fixture(scope="session") -def column_family_config(): - """ - specify column families to create when creating a new test table - """ - from google.cloud.bigtable_admin_v2 import types - - return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} - - -@pytest.fixture(scope="session") -def init_table_id(): - """ - The table_id to use when creating a new test table - """ - return f"test-table-{uuid.uuid4().hex}" - - -@pytest.fixture(scope="session") -def cluster_config(project_id): - """ - Configuration for the clusters to use when creating a new instance - """ - from google.cloud.bigtable_admin_v2 import types - - cluster = { - "test-cluster": types.Cluster( - location=f"projects/{project_id}/locations/us-central1-b", - serve_nodes=1, - ) - } - return cluster - - -class TempRowBuilder: - """ - Used to add rows to a table for testing purposes. - """ - - def __init__(self, table): - self.rows = [] - self.table = table - - async def add_row( - self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" - ): - if isinstance(value, str): - value = value.encode("utf-8") - elif isinstance(value, int): - value = value.to_bytes(8, byteorder="big", signed=True) - request = { - "table_name": self.table.table_name, - "row_key": row_key, - "mutations": [ - { - "set_cell": { - "family_name": family, - "column_qualifier": qualifier, - "value": value, - } - } - ], - } - await self.table.client._gapic_client.mutate_row(request) - self.rows.append(row_key) - - async def delete_rows(self): - if self.rows: - request = { - "table_name": self.table.table_name, - "entries": [ - {"row_key": row, "mutations": [{"delete_from_row": {}}]} - for row in self.rows - ], - } - await self.table.client._gapic_client.mutate_rows(request) - - -@pytest.mark.usefixtures("table") -async def _retrieve_cell_value(table, row_key): - """ - Helper to read an individual row - """ - from google.cloud.bigtable.data import ReadRowsQuery - - row_list = await table.read_rows(ReadRowsQuery(row_keys=row_key)) - assert len(row_list) == 1 - row = row_list[0] - cell = row.cells[0] - return cell.value - - -async def _create_row_and_mutation( - table, temp_rows, *, start_value=b"start", new_value=b"new_value" -): - """ - Helper to create a new row, and a sample set_cell mutation to change its value - """ - from google.cloud.bigtable.data.mutations import SetCell - - row_key = uuid.uuid4().hex.encode() - family = TEST_FAMILY - qualifier = b"test-qualifier" - await temp_rows.add_row( - row_key, family=family, qualifier=qualifier, value=start_value - ) - # ensure cell is initialized - assert (await _retrieve_cell_value(table, row_key)) == start_value - - mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) - return row_key, mutation - - -@pytest_asyncio.fixture(scope="function") -async def temp_rows(table): - builder = TempRowBuilder(table) - yield builder - await builder.delete_rows() - - -@pytest.mark.usefixtures("table") -@pytest.mark.usefixtures("client") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=10) -@pytest.mark.asyncio -async def 
test_ping_and_warm_gapic(client, table): - """ - Simple ping rpc test - This test ensures channels are able to authenticate with backend - """ - request = {"name": table.instance_name} - await client._gapic_client.ping_and_warm(request) - - -@pytest.mark.usefixtures("table") -@pytest.mark.usefixtures("client") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_ping_and_warm(client, table): - """ - Test ping and warm from handwritten client - """ - try: - channel = client.transport._grpc_channel.pool[0] - except Exception: - # for sync client - channel = client.transport._grpc_channel - results = await client._ping_and_warm_instances(channel) - assert len(results) == 1 - assert results[0] is None - - -@pytest.mark.asyncio -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -async def test_mutation_set_cell(table, temp_rows): - """ - Ensure cells can be set properly - """ - row_key = b"bulk_mutate" - new_value = uuid.uuid4().hex.encode() - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - await table.mutate_row(row_key, mutation) - - # ensure cell is updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - - -@pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" -) -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_sample_row_keys(client, table, temp_rows, column_split_config): - """ - Sample keys should return a single sample in small test tables - """ - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - - results = await table.sample_row_keys() - assert len(results) == len(column_split_config) + 1 - # first keys should match the split config - for idx in range(len(column_split_config)): - assert results[idx][0] == column_split_config[idx] - assert isinstance(results[idx][1], int) - # last sample should be empty key - assert results[-1][0] == b"" - assert isinstance(results[-1][1], int) - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_bulk_mutations_set_cell(client, table, temp_rows): - """ - Ensure cells can be set properly - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - - await table.bulk_mutate_rows([bulk_mutation]) - - # ensure cell is updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - - -@pytest.mark.asyncio -async def test_bulk_mutations_raise_exception(client, table): - """ - If an invalid mutation is passed, an exception should be raised - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell - from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup - from google.cloud.bigtable.data.exceptions import FailedMutationEntryError - - row_key = uuid.uuid4().hex.encode() - mutation = SetCell(family="nonexistent", qualifier=b"test-qualifier", new_value=b"") - bulk_mutation = RowMutationEntry(row_key, [mutation]) - - with pytest.raises(MutationsExceptionGroup) as exc: - await 
table.bulk_mutate_rows([bulk_mutation]) - assert len(exc.value.exceptions) == 1 - entry_error = exc.value.exceptions[0] - assert isinstance(entry_error, FailedMutationEntryError) - assert entry_error.index == 0 - assert entry_error.entry == bulk_mutation - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_mutations_batcher_context_manager(client, table, temp_rows): - """ - test batcher with context manager. Should flush on exit - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - row_key2, mutation2 = await _create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - async with table.mutations_batcher() as batcher: - await batcher.append(bulk_mutation) - await batcher.append(bulk_mutation2) - # ensure cell is updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - assert len(batcher._staged_entries) == 0 - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_mutations_batcher_timer_flush(client, table, temp_rows): - """ - batch should occur after flush_interval seconds - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - flush_interval = 0.1 - async with table.mutations_batcher(flush_interval=flush_interval) as batcher: - await batcher.append(bulk_mutation) - await asyncio.sleep(0) - assert len(batcher._staged_entries) == 1 - await asyncio.sleep(flush_interval + 0.1) - assert len(batcher._staged_entries) == 0 - # ensure cell is updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_mutations_batcher_count_flush(client, table, temp_rows): - """ - batch should flush after flush_limit_mutation_count mutations - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = await _create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - async with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: - await batcher.append(bulk_mutation) - assert len(batcher._flush_jobs) == 0 - # should be noop; flush not scheduled - assert len(batcher._staged_entries) == 1 - await batcher.append(bulk_mutation2) - # task should now be scheduled - assert len(batcher._flush_jobs) == 1 - await asyncio.gather(*batcher._flush_jobs) - assert len(batcher._staged_entries) == 0 - assert len(batcher._flush_jobs) == 
0 - # ensure cells were updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - assert (await _retrieve_cell_value(table, row_key2)) == new_value2 - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_mutations_batcher_bytes_flush(client, table, temp_rows): - """ - batch should flush after flush_limit_bytes bytes - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = await _create_row_and_mutation( - table, temp_rows, new_value=new_value2 - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 - - async with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: - await batcher.append(bulk_mutation) - assert len(batcher._flush_jobs) == 0 - assert len(batcher._staged_entries) == 1 - await batcher.append(bulk_mutation2) - # task should now be scheduled - assert len(batcher._flush_jobs) == 1 - assert len(batcher._staged_entries) == 0 - # let flush complete - await asyncio.gather(*batcher._flush_jobs) - # ensure cells were updated - assert (await _retrieve_cell_value(table, row_key)) == new_value - assert (await _retrieve_cell_value(table, row_key2)) == new_value2 - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_mutations_batcher_no_flush(client, table, temp_rows): - """ - test with no flush requirements met - """ - from google.cloud.bigtable.data.mutations import RowMutationEntry - - new_value = uuid.uuid4().hex.encode() - start_value = b"unchanged" - row_key, mutation = await _create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value - ) - bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = await _create_row_and_mutation( - table, temp_rows, start_value=start_value, new_value=new_value - ) - bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) - - size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 - async with table.mutations_batcher( - flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 - ) as batcher: - await batcher.append(bulk_mutation) - assert len(batcher._staged_entries) == 1 - await batcher.append(bulk_mutation2) - # flush not scheduled - assert len(batcher._flush_jobs) == 0 - await asyncio.sleep(0.01) - assert len(batcher._staged_entries) == 2 - assert len(batcher._flush_jobs) == 0 - # ensure cells were not updated - assert (await _retrieve_cell_value(table, row_key)) == start_value - assert (await _retrieve_cell_value(table, row_key2)) == start_value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.parametrize( - "start,increment,expected", - [ - (0, 0, 0), - (0, 1, 1), - (0, -1, -1), - (1, 0, 1), - (0, -100, -100), - (0, 3000, 3000), - (10, 4, 14), - (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), - (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), - (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), - ], -) -@pytest.mark.asyncio -async def test_read_modify_write_row_increment( - client, table, temp_rows, start, increment, expected -): - """ - test 
read_modify_write_row - """ - from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) - - rule = IncrementRule(family, qualifier, increment) - result = await table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert len(result) == 1 - assert result[0].family == family - assert result[0].qualifier == qualifier - assert int(result[0]) == expected - # ensure that reading from server gives same value - assert (await _retrieve_cell_value(table, row_key)) == result[0].value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.parametrize( - "start,append,expected", - [ - (b"", b"", b""), - ("", "", b""), - (b"abc", b"123", b"abc123"), - (b"abc", "123", b"abc123"), - ("", b"1", b"1"), - (b"abc", "", b"abc"), - (b"hello", b"world", b"helloworld"), - ], -) -@pytest.mark.asyncio -async def test_read_modify_write_row_append( - client, table, temp_rows, start, append, expected -): - """ - test read_modify_write_row - """ - from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - await temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) - - rule = AppendValueRule(family, qualifier, append) - result = await table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert len(result) == 1 - assert result[0].family == family - assert result[0].qualifier == qualifier - assert result[0].value == expected - # ensure that reading from server gives same value - assert (await _retrieve_cell_value(table, row_key)) == result[0].value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_modify_write_row_chained(client, table, temp_rows): - """ - test read_modify_write_row with multiple rules - """ - from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule - from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - start_amount = 1 - increment_amount = 10 - await temp_rows.add_row( - row_key, value=start_amount, family=family, qualifier=qualifier - ) - rule = [ - IncrementRule(family, qualifier, increment_amount), - AppendValueRule(family, qualifier, "hello"), - AppendValueRule(family, qualifier, "world"), - AppendValueRule(family, qualifier, "!"), - ] - result = await table.read_modify_write_row(row_key, rule) - assert result.row_key == row_key - assert result[0].family == family - assert result[0].qualifier == qualifier - # result should be a bytes number string for the IncrementRules, followed by the AppendValueRule values - assert ( - result[0].value - == (start_amount + increment_amount).to_bytes(8, "big", signed=True) - + b"helloworld!" 
- ) - # ensure that reading from server gives same value - assert (await _retrieve_cell_value(table, row_key)) == result[0].value - - -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.parametrize( - "start_val,predicate_range,expected_result", - [ - (1, (0, 2), True), - (-1, (0, 2), False), - ], -) -@pytest.mark.asyncio -async def test_check_and_mutate( - client, table, temp_rows, start_val, predicate_range, expected_result -): - """ - test that check_and_mutate_row works applies the right mutations, and returns the right result - """ - from google.cloud.bigtable.data.mutations import SetCell - from google.cloud.bigtable.data.row_filters import ValueRangeFilter - - row_key = b"test-row-key" - family = TEST_FAMILY - qualifier = b"test-qualifier" - - await temp_rows.add_row( - row_key, value=start_val, family=family, qualifier=qualifier - ) - - false_mutation_value = b"false-mutation-value" - false_mutation = SetCell( - family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value - ) - true_mutation_value = b"true-mutation-value" - true_mutation = SetCell( - family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value - ) - predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) - result = await table.check_and_mutate_row( - row_key, - predicate, - true_case_mutations=true_mutation, - false_case_mutations=false_mutation, - ) - assert result == expected_result - # ensure cell is updated - expected_value = true_mutation_value if expected_result else false_mutation_value - assert (await _retrieve_cell_value(table, row_key)) == expected_value - - -@pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", -) -@pytest.mark.usefixtures("client") -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_check_and_mutate_empty_request(client, table): - """ - check_and_mutate with no true or fale mutations should raise an error - """ - from google.api_core import exceptions - - with pytest.raises(exceptions.InvalidArgument) as e: - await table.check_and_mutate_row( - b"row_key", None, true_case_mutations=None, false_case_mutations=None - ) - assert "No mutations provided" in str(e.value) - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_stream(table, temp_rows): - """ - Ensure that the read_rows_stream method works - """ - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - - # full table scan - generator = await table.read_rows_stream({}) - first_row = await generator.__anext__() - second_row = await generator.__anext__() - assert first_row.row_key == b"row_key_1" - assert second_row.row_key == b"row_key_2" - with pytest.raises(StopAsyncIteration): - await generator.__anext__() - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows(table, temp_rows): - """ - Ensure that the read_rows method works - """ - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - # full table scan - row_list = await table.read_rows({}) - assert len(row_list) == 2 - assert row_list[0].row_key == b"row_key_1" - assert row_list[1].row_key == b"row_key_2" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) 
-@pytest.mark.asyncio -async def test_read_rows_sharded_simple(table, temp_rows): - """ - Test read rows sharded with two queries - """ - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) - query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) - row_list = await table.read_rows_sharded([query1, query2]) - assert len(row_list) == 4 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"c" - assert row_list[2].row_key == b"b" - assert row_list[3].row_key == b"d" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_sharded_from_sample(table, temp_rows): - """ - Test end-to-end sharding - """ - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - from google.cloud.bigtable.data.read_rows_query import RowRange - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - - table_shard_keys = await table.sample_row_keys() - query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) - shard_queries = query.shard(table_shard_keys) - row_list = await table.read_rows_sharded(shard_queries) - assert len(row_list) == 3 - assert row_list[0].row_key == b"b" - assert row_list[1].row_key == b"c" - assert row_list[2].row_key == b"d" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_sharded_filters_limits(table, temp_rows): - """ - Test read rows sharded with filters and limits - """ - from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - - label_filter1 = ApplyLabelFilter("first") - label_filter2 = ApplyLabelFilter("second") - query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) - query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) - row_list = await table.read_rows_sharded([query1, query2]) - assert len(row_list) == 3 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"b" - assert row_list[2].row_key == b"d" - assert row_list[0][0].labels == ["first"] - assert row_list[1][0].labels == ["second"] - assert row_list[2][0].labels == ["second"] - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_range_query(table, temp_rows): - """ - Ensure that the read_rows method works - """ - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable.data import RowRange - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - # full table scan - query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) - row_list = await table.read_rows(query) - assert len(row_list) == 2 - assert row_list[0].row_key == b"b" - assert row_list[1].row_key == b"c" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, 
maximum=5) -@pytest.mark.asyncio -async def test_read_rows_single_key_query(table, temp_rows): - """ - Ensure that the read_rows method works with specified query - """ - from google.cloud.bigtable.data import ReadRowsQuery - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - # retrieve specific keys - query = ReadRowsQuery(row_keys=[b"a", b"c"]) - row_list = await table.read_rows(query) - assert len(row_list) == 2 - assert row_list[0].row_key == b"a" - assert row_list[1].row_key == b"c" - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.asyncio -async def test_read_rows_with_filter(table, temp_rows): - """ - ensure filters are applied - """ - from google.cloud.bigtable.data import ReadRowsQuery - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - await temp_rows.add_row(b"a") - await temp_rows.add_row(b"b") - await temp_rows.add_row(b"c") - await temp_rows.add_row(b"d") - # retrieve keys with filter - expected_label = "test-label" - row_filter = ApplyLabelFilter(expected_label) - query = ReadRowsQuery(row_filter=row_filter) - row_list = await table.read_rows(query) - assert len(row_list) == 4 - for row in row_list: - assert row[0].labels == [expected_label] - - -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_rows_stream_close(table, temp_rows): - """ - Ensure that the read_rows_stream can be closed - """ - from google.cloud.bigtable.data import ReadRowsQuery - - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - # full table scan - query = ReadRowsQuery() - generator = await table.read_rows_stream(query) - # grab first row - first_row = await generator.__anext__() - assert first_row.row_key == b"row_key_1" - # close stream early - await generator.aclose() - with pytest.raises(StopAsyncIteration): - await generator.__anext__() - - -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_row(table, temp_rows): - """ - Test read_row (single row helper) - """ - from google.cloud.bigtable.data import Row - - await temp_rows.add_row(b"row_key_1", value=b"value") - row = await table.read_row(b"row_key_1") - assert isinstance(row, Row) - assert row.row_key == b"row_key_1" - assert row.cells[0].value == b"value" - - -@pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", -) -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_row_missing(table): - """ - Test read_row when row does not exist - """ - from google.api_core import exceptions - - row_key = "row_key_not_exist" - result = await table.read_row(row_key) - assert result is None - with pytest.raises(exceptions.InvalidArgument) as e: - await table.read_row("") - assert "Row keys must be non-empty" in str(e) - - -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_read_row_w_filter(table, temp_rows): - """ - Test read_row (single row helper) - """ - from google.cloud.bigtable.data import Row - from google.cloud.bigtable.data.row_filters import ApplyLabelFilter - - await temp_rows.add_row(b"row_key_1", value=b"value") - expected_label = "test-label" - label_filter = ApplyLabelFilter(expected_label) - row = await table.read_row(b"row_key_1", row_filter=label_filter) - assert isinstance(row, Row) - assert row.row_key == b"row_key_1" - assert row.cells[0].value == b"value" - 
assert row.cells[0].labels == [expected_label] - - -@pytest.mark.skipif( - bool(os.environ.get(BIGTABLE_EMULATOR)), - reason="emulator doesn't raise InvalidArgument", -) -@pytest.mark.usefixtures("table") -@pytest.mark.asyncio -async def test_row_exists(table, temp_rows): - from google.api_core import exceptions - - """Test row_exists with rows that exist and don't exist""" - assert await table.row_exists(b"row_key_1") is False - await temp_rows.add_row(b"row_key_1") - assert await table.row_exists(b"row_key_1") is True - assert await table.row_exists("row_key_1") is True - assert await table.row_exists(b"row_key_2") is False - assert await table.row_exists("row_key_2") is False - assert await table.row_exists("3") is False - await temp_rows.add_row(b"3") - assert await table.row_exists(b"3") is True - with pytest.raises(exceptions.InvalidArgument) as e: - await table.row_exists("") - assert "Row keys must be non-empty" in str(e) - - -@pytest.mark.usefixtures("table") -@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5) -@pytest.mark.parametrize( - "cell_value,filter_input,expect_match", - [ - (b"abc", b"abc", True), - (b"abc", "abc", True), - (b".", ".", True), - (".*", ".*", True), - (".*", b".*", True), - ("a", ".*", False), - (b".*", b".*", True), - (r"\a", r"\a", True), - (b"\xe2\x98\x83", "☃", True), - ("☃", "☃", True), - (r"\C☃", r"\C☃", True), - (1, 1, True), - (2, 1, False), - (68, 68, True), - ("D", 68, False), - (68, "D", False), - (-1, -1, True), - (2852126720, 2852126720, True), - (-1431655766, -1431655766, True), - (-1431655766, -1, False), - ], -) -@pytest.mark.asyncio -async def test_literal_value_filter( - table, temp_rows, cell_value, filter_input, expect_match -): - """ - Literal value filter does complex escaping on re2 strings. - Make sure inputs are properly interpreted by the server - """ - from google.cloud.bigtable.data.row_filters import LiteralValueFilter - from google.cloud.bigtable.data import ReadRowsQuery - - f = LiteralValueFilter(filter_input) - await temp_rows.add_row(b"row_key_1", value=cell_value) - query = ReadRowsQuery(row_filter=f) - row_list = await table.read_rows(query) - assert len(row_list) == bool( - expect_match - ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py new file mode 100644 index 000000000..d12936305 --- /dev/null +++ b/tests/system/data/test_system_async.py @@ -0,0 +1,992 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import asyncio +import uuid +import os +from google.api_core import retry +from google.api_core.exceptions import ClientError + +from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE +from google.cloud.environment_vars import BIGTABLE_EMULATOR + +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +from . 
import TEST_FAMILY, TEST_FAMILY_2 + + +@CrossSync.export_sync( + path="tests.system.data.test_system.TempRowBuilder", + add_mapping_for_name="TempRowBuilder", +) +class TempRowBuilderAsync: + """ + Used to add rows to a table for testing purposes. + """ + + def __init__(self, table): + self.rows = [] + self.table = table + + @CrossSync.convert + async def add_row( + self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" + ): + if isinstance(value, str): + value = value.encode("utf-8") + elif isinstance(value, int): + value = value.to_bytes(8, byteorder="big", signed=True) + request = { + "table_name": self.table.table_name, + "row_key": row_key, + "mutations": [ + { + "set_cell": { + "family_name": family, + "column_qualifier": qualifier, + "value": value, + } + } + ], + } + await self.table.client._gapic_client.mutate_row(request) + self.rows.append(row_key) + + @CrossSync.convert + async def delete_rows(self): + if self.rows: + request = { + "table_name": self.table.table_name, + "entries": [ + {"row_key": row, "mutations": [{"delete_from_row": {}}]} + for row in self.rows + ], + } + await self.table.client._gapic_client.mutate_rows(request) + + +@CrossSync.export_sync(path="tests.system.data.test_system.TestSystem") +class TestSystemAsync: + @CrossSync.convert + @CrossSync.pytest_fixture(scope="session") + async def client(self): + project = os.getenv("GOOGLE_CLOUD_PROJECT") or None + async with CrossSync.DataClient(project=project, pool_size=4) as client: + yield client + + @CrossSync.convert + @CrossSync.pytest_fixture(scope="session") + async def table(self, client, table_id, instance_id): + async with client.get_table( + instance_id, + table_id, + ) as table: + yield table + + @pytest.fixture(scope="session") + def event_loop(self): + loop = asyncio.get_event_loop() + yield loop + loop.stop() + loop.close() + + @pytest.fixture(scope="session") + def column_family_config(self): + """ + specify column families to create when creating a new test table + """ + from google.cloud.bigtable_admin_v2 import types + + return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} + + @pytest.fixture(scope="session") + def init_table_id(self): + """ + The table_id to use when creating a new test table + """ + return f"test-table-{uuid.uuid4().hex}" + + @pytest.fixture(scope="session") + def cluster_config(self, project_id): + """ + Configuration for the clusters to use when creating a new instance + """ + from google.cloud.bigtable_admin_v2 import types + + cluster = { + "test-cluster": types.Cluster( + location=f"projects/{project_id}/locations/us-central1-b", + serve_nodes=1, + ) + } + return cluster + + @CrossSync.convert + @pytest.mark.usefixtures("table") + async def _retrieve_cell_value(self, table, row_key): + """ + Helper to read an individual row + """ + from google.cloud.bigtable.data import ReadRowsQuery + + row_list = await table.read_rows(ReadRowsQuery(row_keys=row_key)) + assert len(row_list) == 1 + row = row_list[0] + cell = row.cells[0] + return cell.value + + @CrossSync.convert + async def _create_row_and_mutation( + self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" + ): + """ + Helper to create a new row, and a sample set_cell mutation to change its value + """ + from google.cloud.bigtable.data.mutations import SetCell + + row_key = uuid.uuid4().hex.encode() + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row( + row_key, family=family, qualifier=qualifier, value=start_value + ) + # ensure 
cell is initialized + assert (await self._retrieve_cell_value(table, row_key)) == start_value + + mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) + return row_key, mutation + + @CrossSync.convert + @CrossSync.pytest_fixture(scope="function") + async def temp_rows(self, table): + builder = CrossSync.TempRowBuilder(table) + yield builder + await builder.delete_rows() + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=10 + ) + @CrossSync.pytest + async def test_ping_and_warm_gapic(self, client, table): + """ + Simple ping rpc test + This test ensures channels are able to authenticate with backend + """ + request = {"name": table.instance_name} + await client._gapic_client.ping_and_warm(request) + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_ping_and_warm(self, client, table): + """ + Test ping and warm from handwritten client + """ + try: + channel = client.transport._grpc_channel.pool[0] + except Exception: + # for sync client + channel = client.transport._grpc_channel + results = await client._ping_and_warm_instances(channel) + assert len(results) == 1 + assert results[0] is None + + @CrossSync.pytest + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + async def test_mutation_set_cell(self, table, temp_rows): + """ + Ensure cells can be set properly + """ + row_key = b"bulk_mutate" + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + await table.mutate_row(row_key, mutation) + + # ensure cell is updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" + ) + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_sample_row_keys(self, client, table, temp_rows, column_split_config): + """ + Sample keys should return a single sample in small test tables + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + + results = await table.sample_row_keys() + assert len(results) == len(column_split_config) + 1 + # first keys should match the split config + for idx in range(len(column_split_config)): + assert results[idx][0] == column_split_config[idx] + assert isinstance(results[idx][1], int) + # last sample should be empty key + assert results[-1][0] == b"" + assert isinstance(results[-1][1], int) + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_bulk_mutations_set_cell(self, client, table, temp_rows): + """ + Ensure cells can be set properly + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + + await table.bulk_mutate_rows([bulk_mutation]) + + # ensure cell is updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + + 
@CrossSync.pytest + async def test_bulk_mutations_raise_exception(self, client, table): + """ + If an invalid mutation is passed, an exception should be raised + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + row_key = uuid.uuid4().hex.encode() + mutation = SetCell( + family="nonexistent", qualifier=b"test-qualifier", new_value=b"" + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + + with pytest.raises(MutationsExceptionGroup) as exc: + await table.bulk_mutate_rows([bulk_mutation]) + assert len(exc.value.exceptions) == 1 + entry_error = exc.value.exceptions[0] + assert isinstance(entry_error, FailedMutationEntryError) + assert entry_error.index == 0 + assert entry_error.entry == bulk_mutation + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_mutations_batcher_context_manager(self, client, table, temp_rows): + """ + test batcher with context manager. Should flush on exit + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + row_key2, mutation2 = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + async with table.mutations_batcher() as batcher: + await batcher.append(bulk_mutation) + await batcher.append(bulk_mutation2) + # ensure cell is updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + assert len(batcher._staged_entries) == 0 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_mutations_batcher_timer_flush(self, client, table, temp_rows): + """ + batch should occur after flush_interval seconds + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + flush_interval = 0.1 + async with table.mutations_batcher(flush_interval=flush_interval) as batcher: + await batcher.append(bulk_mutation) + await CrossSync.yield_to_event_loop() + assert len(batcher._staged_entries) == 1 + await CrossSync.sleep(flush_interval + 0.1) + assert len(batcher._staged_entries) == 0 + # ensure cell is updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_mutations_batcher_count_flush(self, client, table, temp_rows): + """ + batch should flush after flush_limit_mutation_count mutations + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await self._create_row_and_mutation( + table, 
temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + async with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + # should be noop; flush not scheduled + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # task should now be scheduled + assert len(batcher._flush_jobs) == 1 + # let flush complete + for future in list(batcher._flush_jobs): + await future + # for sync version: grab result + future.result() + assert len(batcher._staged_entries) == 0 + assert len(batcher._flush_jobs) == 0 + # ensure cells were updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + assert (await self._retrieve_cell_value(table, row_key2)) == new_value2 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): + """ + batch should flush after flush_limit_bytes bytes + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 + + async with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + assert len(batcher._staged_entries) == 1 + await batcher.append(bulk_mutation2) + # task should now be scheduled + assert len(batcher._flush_jobs) == 1 + assert len(batcher._staged_entries) == 0 + # let flush complete + for future in list(batcher._flush_jobs): + await future + # for sync version: grab result + future.result() + # ensure cells were updated + assert (await self._retrieve_cell_value(table, row_key)) == new_value + assert (await self._retrieve_cell_value(table, row_key2)) == new_value2 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_mutations_batcher_no_flush(self, client, table, temp_rows): + """ + test with no flush requirements met + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + start_value = b"unchanged" + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = await self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + + size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 + async with table.mutations_batcher( + flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 + ) as batcher: + await batcher.append(bulk_mutation) + assert len(batcher._staged_entries) == 1 + await 
batcher.append(bulk_mutation2) + # flush not scheduled + assert len(batcher._flush_jobs) == 0 + await CrossSync.yield_to_event_loop() + assert len(batcher._staged_entries) == 2 + assert len(batcher._flush_jobs) == 0 + # ensure cells were not updated + assert (await self._retrieve_cell_value(table, row_key)) == start_value + assert (await self._retrieve_cell_value(table, row_key2)) == start_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,increment,expected", + [ + (0, 0, 0), + (0, 1, 1), + (0, -1, -1), + (1, 0, 1), + (0, -100, -100), + (0, 3000, 3000), + (10, 4, 14), + (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), + (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), + (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), + ], + ) + @CrossSync.pytest + async def test_read_modify_write_row_increment( + self, client, table, temp_rows, start, increment, expected + ): + """ + test read_modify_write_row + """ + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row( + row_key, value=start, family=family, qualifier=qualifier + ) + + rule = IncrementRule(family, qualifier, increment) + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert int(result[0]) == expected + # ensure that reading from server gives same value + assert (await self._retrieve_cell_value(table, row_key)) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,append,expected", + [ + (b"", b"", b""), + ("", "", b""), + (b"abc", b"123", b"abc123"), + (b"abc", "123", b"abc123"), + ("", b"1", b"1"), + (b"abc", "", b"abc"), + (b"hello", b"world", b"helloworld"), + ], + ) + @CrossSync.pytest + async def test_read_modify_write_row_append( + self, client, table, temp_rows, start, append, expected + ): + """ + test read_modify_write_row + """ + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + await temp_rows.add_row( + row_key, value=start, family=family, qualifier=qualifier + ) + + rule = AppendValueRule(family, qualifier, append) + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert result[0].value == expected + # ensure that reading from server gives same value + assert (await self._retrieve_cell_value(table, row_key)) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_read_modify_write_row_chained(self, client, table, temp_rows): + """ + test read_modify_write_row with multiple rules + """ + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + start_amount = 1 + increment_amount = 10 + await temp_rows.add_row( + row_key, value=start_amount, family=family, qualifier=qualifier + ) + rule = [ + IncrementRule(family, qualifier, increment_amount), + AppendValueRule(family, qualifier, 
"hello"), + AppendValueRule(family, qualifier, "world"), + AppendValueRule(family, qualifier, "!"), + ] + result = await table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert result[0].family == family + assert result[0].qualifier == qualifier + # result should be a bytes number string for the IncrementRules, followed by the AppendValueRule values + assert ( + result[0].value + == (start_amount + increment_amount).to_bytes(8, "big", signed=True) + + b"helloworld!" + ) + # ensure that reading from server gives same value + assert (await self._retrieve_cell_value(table, row_key)) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start_val,predicate_range,expected_result", + [ + (1, (0, 2), True), + (-1, (0, 2), False), + ], + ) + @CrossSync.pytest + async def test_check_and_mutate( + self, client, table, temp_rows, start_val, predicate_range, expected_result + ): + """ + test that check_and_mutate_row works applies the right mutations, and returns the right result + """ + from google.cloud.bigtable.data.mutations import SetCell + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + + await temp_rows.add_row( + row_key, value=start_val, family=family, qualifier=qualifier + ) + + false_mutation_value = b"false-mutation-value" + false_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value + ) + true_mutation_value = b"true-mutation-value" + true_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value + ) + predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) + result = await table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + assert result == expected_result + # ensure cell is updated + expected_value = ( + true_mutation_value if expected_result else false_mutation_value + ) + assert (await self._retrieve_cell_value(table, row_key)) == expected_value + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_check_and_mutate_empty_request(self, client, table): + """ + check_and_mutate with no true or fale mutations should raise an error + """ + from google.api_core import exceptions + + with pytest.raises(exceptions.InvalidArgument) as e: + await table.check_and_mutate_row( + b"row_key", None, true_case_mutations=None, false_case_mutations=None + ) + assert "No mutations provided" in str(e.value) + + @pytest.mark.usefixtures("table") + @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows_stream(self, table, temp_rows): + """ + Ensure that the read_rows_stream method works + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + + # full table scan + generator = await table.read_rows_stream({}) + first_row = await generator.__anext__() + second_row = await generator.__anext__() + assert first_row.row_key == b"row_key_1" + assert second_row.row_key == b"row_key_2" + with pytest.raises(CrossSync.StopIteration): + await generator.__anext__() + + 
@pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows(self, table, temp_rows): + """ + Ensure that the read_rows method works + """ + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + # full table scan + row_list = await table.read_rows({}) + assert len(row_list) == 2 + assert row_list[0].row_key == b"row_key_1" + assert row_list[1].row_key == b"row_key_2" + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows_sharded_simple(self, table, temp_rows): + """ + Test read rows sharded with two queries + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) + row_list = await table.read_rows_sharded([query1, query2]) + assert len(row_list) == 4 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"b" + assert row_list[3].row_key == b"d" + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows_sharded_from_sample(self, table, temp_rows): + """ + Test end-to-end sharding + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.read_rows_query import RowRange + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + + table_shard_keys = await table.sample_row_keys() + query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) + shard_queries = query.shard(table_shard_keys) + row_list = await table.read_rows_sharded(shard_queries) + assert len(row_list) == 3 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"d" + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows_sharded_filters_limits(self, table, temp_rows): + """ + Test read rows sharded with filters and limits + """ + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + + label_filter1 = ApplyLabelFilter("first") + label_filter2 = ApplyLabelFilter("second") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) + row_list = await table.read_rows_sharded([query1, query2]) + assert len(row_list) == 3 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"b" + assert row_list[2].row_key == b"d" + assert row_list[0][0].labels == ["first"] + assert row_list[1][0].labels == ["second"] + assert row_list[2][0].labels == ["second"] + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def 
test_read_rows_range_query(self, table, temp_rows): + """ + Ensure that the read_rows method works + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data import RowRange + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # full table scan + query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) + row_list = await table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows_single_key_query(self, table, temp_rows): + """ + Ensure that the read_rows method works with specified query + """ + from google.cloud.bigtable.data import ReadRowsQuery + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # retrieve specific keys + query = ReadRowsQuery(row_keys=[b"a", b"c"]) + row_list = await table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_read_rows_with_filter(self, table, temp_rows): + """ + ensure filters are applied + """ + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"a") + await temp_rows.add_row(b"b") + await temp_rows.add_row(b"c") + await temp_rows.add_row(b"d") + # retrieve keys with filter + expected_label = "test-label" + row_filter = ApplyLabelFilter(expected_label) + query = ReadRowsQuery(row_filter=row_filter) + row_list = await table.read_rows(query) + assert len(row_list) == 4 + for row in row_list: + assert row[0].labels == [expected_label] + + @pytest.mark.usefixtures("table") + @CrossSync.convert(replace_symbols={"__anext__": "__next__", "aclose": "close"}) + @CrossSync.pytest + async def test_read_rows_stream_close(self, table, temp_rows): + """ + Ensure that the read_rows_stream can be closed + """ + from google.cloud.bigtable.data import ReadRowsQuery + + await temp_rows.add_row(b"row_key_1") + await temp_rows.add_row(b"row_key_2") + # full table scan + query = ReadRowsQuery() + generator = await table.read_rows_stream(query) + # grab first row + first_row = await generator.__anext__() + assert first_row.row_key == b"row_key_1" + # close stream early + await generator.aclose() + with pytest.raises(CrossSync.StopIteration): + await generator.__anext__() + + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_read_row(self, table, temp_rows): + """ + Test read_row (single row helper) + """ + from google.cloud.bigtable.data import Row + + await temp_rows.add_row(b"row_key_1", value=b"value") + row = await table.read_row(b"row_key_1") + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_read_row_missing(self, table): + """ + Test read_row when row does not exist + """ + from google.api_core import exceptions + + 
row_key = "row_key_not_exist" + result = await table.read_row(row_key) + assert result is None + with pytest.raises(exceptions.InvalidArgument) as e: + await table.read_row("") + assert "Row keys must be non-empty" in str(e) + + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_read_row_w_filter(self, table, temp_rows): + """ + Test read_row (single row helper) + """ + from google.cloud.bigtable.data import Row + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + await temp_rows.add_row(b"row_key_1", value=b"value") + expected_label = "test-label" + label_filter = ApplyLabelFilter(expected_label) + row = await table.read_row(b"row_key_1", row_filter=label_filter) + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + assert row.cells[0].labels == [expected_label] + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + @CrossSync.pytest + async def test_row_exists(self, table, temp_rows): + from google.api_core import exceptions + + """Test row_exists with rows that exist and don't exist""" + assert await table.row_exists(b"row_key_1") is False + await temp_rows.add_row(b"row_key_1") + assert await table.row_exists(b"row_key_1") is True + assert await table.row_exists("row_key_1") is True + assert await table.row_exists(b"row_key_2") is False + assert await table.row_exists("row_key_2") is False + assert await table.row_exists("3") is False + await temp_rows.add_row(b"3") + assert await table.row_exists(b"3") is True + with pytest.raises(exceptions.InvalidArgument) as e: + await table.row_exists("") + assert "Row keys must be non-empty" in str(e) + + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @pytest.mark.parametrize( + "cell_value,filter_input,expect_match", + [ + (b"abc", b"abc", True), + (b"abc", "abc", True), + (b".", ".", True), + (".*", ".*", True), + (".*", b".*", True), + ("a", ".*", False), + (b".*", b".*", True), + (r"\a", r"\a", True), + (b"\xe2\x98\x83", "☃", True), + ("☃", "☃", True), + (r"\C☃", r"\C☃", True), + (1, 1, True), + (2, 1, False), + (68, 68, True), + ("D", 68, False), + (68, "D", False), + (-1, -1, True), + (2852126720, 2852126720, True), + (-1431655766, -1431655766, True), + (-1431655766, -1, False), + ], + ) + @CrossSync.pytest + async def test_literal_value_filter( + self, table, temp_rows, cell_value, filter_input, expect_match + ): + """ + Literal value filter does complex escaping on re2 strings. 
+ Make sure inputs are properly interpreted by the server + """ + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + from google.cloud.bigtable.data import ReadRowsQuery + + f = LiteralValueFilter(filter_input) + await temp_rows.add_row(b"row_key_1", value=cell_value) + query = ReadRowsQuery(row_filter=f) + row_list = await table.read_rows(query) + assert len(row_list) == bool( + expect_match + ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" diff --git a/tests/unit/data/_async/__init__.py b/tests/unit/data/_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index e03028c45..a307a7008 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -16,42 +16,42 @@ from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 -import google.api_core.exceptions as core_exceptions +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import Forbidden + +from google.cloud.bigtable.data._sync.cross_sync import CrossSync # try/except added for compatibility with python < 3.8 try: from unittest import mock - from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER import mock # type: ignore - from mock import AsyncMock # type: ignore - - -def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation +@CrossSync.export_sync( + path="tests.unit.data._sync.test__mutate_rows.TestMutateRowsOperation", +) class TestMutateRowsOperation: def _target_class(self): - from google.cloud.bigtable.data._async._mutate_rows import ( - _MutateRowsOperationAsync, - ) - - return _MutateRowsOperationAsync + return CrossSync._MutateRowsOperation def _make_one(self, *args, **kwargs): if not args: kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) - kwargs["table"] = kwargs.pop("table", AsyncMock()) + kwargs["table"] = kwargs.pop("table", CrossSync.Mock()) kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) return self._target_class()(*args, **kwargs) + def _make_mutation(self, count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + @CrossSync.convert async def _mock_stream(self, mutation_list, error_dict): for idx, entry in enumerate(mutation_list): code = error_dict.get(idx, 0) @@ -64,7 +64,7 @@ async def _mock_stream(self, mutation_list, error_dict): ) def _make_mock_gapic(self, mutation_list, error_dict=None): - mock_fn = AsyncMock() + mock_fn = CrossSync.Mock() if error_dict is None: error_dict = {} mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( @@ -83,7 +83,7 @@ def test_ctor(self): client = mock.Mock() table = mock.Mock() - entries = [_make_mutation(), _make_mutation()] + entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 attempt_timeout = 0.01 retryable_exceptions = () @@ -136,17 +136,14 @@ def test_ctor_too_many_entries(self): client = mock.Mock() table = mock.Mock() - entries = [_make_mutation()] * 
_MUTATE_ROWS_REQUEST_MUTATION_LIMIT + entries = [self._make_mutation()] * (_MUTATE_ROWS_REQUEST_MUTATION_LIMIT + 1) operation_timeout = 0.05 attempt_timeout = 0.01 - # no errors if at limit - self._make_one(client, table, entries, operation_timeout, attempt_timeout) - # raise error after crossing with pytest.raises(ValueError) as e: self._make_one( client, table, - entries + [_make_mutation()], + entries, operation_timeout, attempt_timeout, ) @@ -155,18 +152,18 @@ def test_ctor_too_many_entries(self): ) assert "Found 100001" in str(e.value) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_rows_operation(self): """ Test successful case of mutate_rows_operation """ client = mock.Mock() table = mock.Mock() - entries = [_make_mutation(), _make_mutation()] + entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 cls = self._target_class() with mock.patch( - f"{cls.__module__}.{cls.__name__}._run_attempt", AsyncMock() + f"{cls.__module__}.{cls.__name__}._run_attempt", CrossSync.Mock() ) as attempt_mock: instance = self._make_one( client, table, entries, operation_timeout, operation_timeout @@ -174,17 +171,15 @@ async def test_mutate_rows_operation(self): await instance.start() assert attempt_mock.call_count == 1 - @pytest.mark.parametrize( - "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] - ) - @pytest.mark.asyncio + @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) + @CrossSync.pytest async def test_mutate_rows_attempt_exception(self, exc_type): """ exceptions raised from attempt should be raised in MutationsExceptionGroup """ - client = AsyncMock() + client = CrossSync.Mock() table = mock.Mock() - entries = [_make_mutation(), _make_mutation()] + entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 expected_exception = exc_type("test") client.mutate_rows.side_effect = expected_exception @@ -202,10 +197,8 @@ async def test_mutate_rows_attempt_exception(self, exc_type): assert len(instance.errors) == 2 assert len(instance.remaining_indices) == 0 - @pytest.mark.parametrize( - "exc_type", [RuntimeError, ZeroDivisionError, core_exceptions.Forbidden] - ) - @pytest.mark.asyncio + @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) + @CrossSync.pytest async def test_mutate_rows_exception(self, exc_type): """ exceptions raised from retryable should be raised in MutationsExceptionGroup @@ -215,13 +208,13 @@ async def test_mutate_rows_exception(self, exc_type): client = mock.Mock() table = mock.Mock() - entries = [_make_mutation(), _make_mutation()] + entries = [self._make_mutation(), self._make_mutation()] operation_timeout = 0.05 expected_cause = exc_type("abort") with mock.patch.object( self._target_class(), "_run_attempt", - AsyncMock(), + CrossSync.Mock(), ) as attempt_mock: attempt_mock.side_effect = expected_cause found_exc = None @@ -241,27 +234,24 @@ async def test_mutate_rows_exception(self, exc_type): @pytest.mark.parametrize( "exc_type", - [core_exceptions.DeadlineExceeded, RuntimeError], + [DeadlineExceeded, RuntimeError], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): """ If an exception fails but eventually passes, it should not raise an exception """ - from google.cloud.bigtable.data._async._mutate_rows import ( - _MutateRowsOperationAsync, - ) client = mock.Mock() table = mock.Mock() - entries = [_make_mutation()] + entries = [self._make_mutation()] 
operation_timeout = 1 expected_cause = exc_type("retry") num_retries = 2 with mock.patch.object( - _MutateRowsOperationAsync, + self._target_class(), "_run_attempt", - AsyncMock(), + CrossSync.Mock(), ) as attempt_mock: attempt_mock.side_effect = [expected_cause] * num_retries + [None] instance = self._make_one( @@ -275,7 +265,7 @@ async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): await instance.start() assert attempt_mock.call_count == num_retries + 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_rows_incomplete_ignored(self): """ MutateRowsIncomplete exceptions should not be added to error list @@ -286,12 +276,12 @@ async def test_mutate_rows_incomplete_ignored(self): client = mock.Mock() table = mock.Mock() - entries = [_make_mutation()] + entries = [self._make_mutation()] operation_timeout = 0.05 with mock.patch.object( self._target_class(), "_run_attempt", - AsyncMock(), + CrossSync.Mock(), ) as attempt_mock: attempt_mock.side_effect = _MutateRowsIncomplete("ignored") found_exc = None @@ -306,10 +296,10 @@ async def test_mutate_rows_incomplete_ignored(self): assert len(found_exc.exceptions) == 1 assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) - @pytest.mark.asyncio + @CrossSync.pytest async def test_run_attempt_single_entry_success(self): """Test mutating a single entry""" - mutation = _make_mutation() + mutation = self._make_mutation() expected_timeout = 1.3 mock_gapic_fn = self._make_mock_gapic({0: mutation}) instance = self._make_one( @@ -324,7 +314,7 @@ async def test_run_attempt_single_entry_success(self): assert kwargs["timeout"] == expected_timeout assert kwargs["entries"] == [mutation._to_pb()] - @pytest.mark.asyncio + @CrossSync.pytest async def test_run_attempt_empty_request(self): """Calling with no mutations should result in no API calls""" mock_gapic_fn = self._make_mock_gapic([]) @@ -334,14 +324,14 @@ async def test_run_attempt_empty_request(self): await instance._run_attempt() assert mock_gapic_fn.call_count == 0 - @pytest.mark.asyncio + @CrossSync.pytest async def test_run_attempt_partial_success_retryable(self): """Some entries succeed, but one fails. Should report the proper index, and raise incomplete exception""" from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete - success_mutation = _make_mutation() - success_mutation_2 = _make_mutation() - failure_mutation = _make_mutation() + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() mutations = [success_mutation, failure_mutation, success_mutation_2] mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) instance = self._make_one( @@ -357,12 +347,12 @@ async def test_run_attempt_partial_success_retryable(self): assert instance.errors[1][0].grpc_status_code == 300 assert 2 not in instance.errors - @pytest.mark.asyncio + @CrossSync.pytest async def test_run_attempt_partial_success_non_retryable(self): """Some entries succeed, but one fails. Exception marked as non-retryable. 
Do not raise incomplete error""" - success_mutation = _make_mutation() - success_mutation_2 = _make_mutation() - failure_mutation = _make_mutation() + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() mutations = [success_mutation, failure_mutation, success_mutation_2] mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) instance = self._make_one( diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 2bf8688fd..896c17879 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -13,23 +13,19 @@ import pytest -from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync +from google.cloud.bigtable.data._sync.cross_sync import CrossSync # try/except added for compatibility with python < 3.8 try: from unittest import mock - from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER import mock # type: ignore - from mock import AsyncMock # type: ignore # noqa F401 -TEST_FAMILY = "family_name" -TEST_QUALIFIER = b"qualifier" -TEST_TIMESTAMP = 123456789 -TEST_LABELS = ["label1", "label2"] - -class TestReadRowsOperation: +@CrossSync.export_sync( + path="tests.unit.data._sync.test__read_rows.TestReadRowsOperation", +) +class TestReadRowsOperationAsync: """ Tests helper functions in the ReadRowsOperation class in-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt @@ -37,10 +33,9 @@ class TestReadRowsOperation: """ @staticmethod + @CrossSync.convert def _get_target_class(): - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - - return _ReadRowsOperationAsync + return CrossSync._ReadRowsOperation def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) @@ -60,8 +55,9 @@ def test_ctor(self): expected_operation_timeout = 42 expected_request_timeout = 44 time_gen_mock = mock.Mock() + subpath = "_async" if CrossSync.is_async else "_sync" with mock.patch( - "google.cloud.bigtable.data._async._read_rows._attempt_timeout_generator", + f"google.cloud.bigtable.data.{subpath}._read_rows._attempt_timeout_generator", time_gen_mock, ): instance = self._make_one( @@ -242,7 +238,7 @@ def test_revise_to_empty_rowset(self): (4, 2, 2), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_revise_limit(self, start_limit, emit_num, expected_limit): """ revise_limit should revise the request's limit field @@ -283,7 +279,7 @@ async def mock_stream(): assert instance._remaining_count == expected_limit @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_revise_limit_over_limit(self, start_limit, emit_num): """ Should raise runtime error if we get in state where emit_num > start_num @@ -322,7 +318,11 @@ async def mock_stream(): pass assert "emit count exceeds row limit" in str(e.value) - @pytest.mark.asyncio + @CrossSync.pytest + @CrossSync.convert( + sync_name="test_close", + replace_symbols={"aclose": "close", "__anext__": "__next__"}, + ) async def test_aclose(self): """ should be able to close a stream safely with aclose. 
@@ -334,7 +334,7 @@ async def mock_stream(): yield 1 with mock.patch.object( - _ReadRowsOperationAsync, "_read_rows_attempt" + self._get_target_class(), "_read_rows_attempt" ) as mock_attempt: instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) wrapped_gen = mock_stream() @@ -343,20 +343,20 @@ async def mock_stream(): # read one row await gen.__anext__() await gen.aclose() - with pytest.raises(StopAsyncIteration): + with pytest.raises(CrossSync.StopIteration): await gen.__anext__() # try calling a second time await gen.aclose() # ensure close was propagated to wrapped generator - with pytest.raises(StopAsyncIteration): + with pytest.raises(CrossSync.StopIteration): await wrapped_gen.__anext__() - @pytest.mark.asyncio + @CrossSync.pytest + @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) async def test_retryable_ignore_repeated_rows(self): """ Duplicate rows should cause an invalid chunk error """ - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable_v2.types import ReadRowsResponse @@ -381,37 +381,10 @@ async def mock_stream(): instance = mock.Mock() instance._last_yielded_row_key = None instance._remaining_count = None - stream = _ReadRowsOperationAsync.chunk_stream(instance, mock_awaitable_stream()) + stream = self._get_target_class().chunk_stream( + instance, mock_awaitable_stream() + ) await stream.__anext__() with pytest.raises(InvalidChunk) as exc: await stream.__anext__() assert "row keys should be strictly increasing" in str(exc.value) - - -class MockStream(_ReadRowsOperationAsync): - """ - Mock a _ReadRowsOperationAsync stream for testing - """ - - def __init__(self, items=None, errors=None, operation_timeout=None): - self.transient_errors = errors - self.operation_timeout = operation_timeout - self.next_idx = 0 - if items is None: - items = list(range(10)) - self.items = items - - def __aiter__(self): - return self - - async def __anext__(self): - if self.next_idx >= len(self.items): - raise StopAsyncIteration - item = self.items[self.next_idx] - self.next_idx += 1 - if isinstance(item, Exception): - raise item - return item - - async def aclose(self): - pass diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 9ebc403ce..b51987c5d 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -32,57 +32,62 @@ from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + # try/except added for compatibility with python < 3.8 try: from unittest import mock - from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER import mock # type: ignore - from mock import AsyncMock # type: ignore - -VENEER_HEADER_REGEX = re.compile( - r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-data-async gl-python\/[0-9]+\.[\w.-]+ grpc\/[0-9]+\.[\w.-]+" -) +if CrossSync.is_async: + from google.api_core import grpc_helpers_async + from google.cloud.bigtable.data._async.client import TableAsync -def _make_client(*args, use_emulator=True, **kwargs): - import os - from google.cloud.bigtable.data._async.client import BigtableDataClientAsync - - env_mask = {} - # by default, use emulator mode to avoid auth issues in CI - # emulator mode must be disabled by tests that check channel 
pooling/refresh background tasks - if use_emulator: - env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" - else: - # set some default values - kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) - kwargs["project"] = kwargs.get("project", "project-id") - with mock.patch.dict(os.environ, env_mask): - return BigtableDataClientAsync(*args, **kwargs) + CrossSync.add_mapping("grpc_helpers", grpc_helpers_async) +@CrossSync.export_sync( + path="tests.unit.data._sync.test_client.TestBigtableDataClient", + add_mapping_for_name="TestBigtableDataClient", +) class TestBigtableDataClientAsync: - def _get_target_class(self): - from google.cloud.bigtable.data._async.client import BigtableDataClientAsync - - return BigtableDataClientAsync - - def _make_one(self, *args, **kwargs): - return _make_client(*args, **kwargs) - - @pytest.mark.asyncio + @staticmethod + @CrossSync.convert + def _get_target_class(): + return CrossSync.DataClient + + @classmethod + def _make_client(cls, *args, use_emulator=True, **kwargs): + import os + + env_mask = {} + # by default, use emulator mode to avoid auth issues in CI + # emulator mode must be disabled by tests that check channel pooling/refresh background tasks + if use_emulator: + env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" + import warnings + + warnings.filterwarnings("ignore", category=RuntimeWarning) + else: + # set some default values + kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) + kwargs["project"] = kwargs.get("project", "project-id") + with mock.patch.dict(os.environ, env_mask): + return cls._get_target_class()(*args, **kwargs) + + @CrossSync.pytest async def test_ctor(self): expected_project = "project-id" expected_pool_size = 11 expected_credentials = AnonymousCredentials() - client = self._make_one( + client = self._make_client( project="project-id", pool_size=expected_pool_size, credentials=expected_credentials, use_emulator=False, ) - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert client.project == expected_project assert len(client.transport._grpc_channel._pool) == expected_pool_size assert not client._active_instances @@ -90,28 +95,29 @@ async def test_ctor(self): assert client.transport._credentials == expected_credentials await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor_super_inits(self): - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) from google.cloud.client import ClientWithProject from google.api_core import client_options as client_options_lib + from google.cloud.bigtable import __version__ as bigtable_version project = "project-id" pool_size = 11 credentials = AnonymousCredentials() client_options = {"api_endpoint": "foo.bar:1234"} options_parsed = client_options_lib.from_dict(client_options) - transport_str = f"pooled_grpc_asyncio_{pool_size}" - with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: + asyncio_portion = "-async" if CrossSync.is_async else "" + transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" + with mock.patch.object( + CrossSync.GapicClient, "__init__" + ) as bigtable_client_init: bigtable_client_init.return_value = None with mock.patch.object( ClientWithProject, "__init__" ) as client_project_init: client_project_init.return_value = None try: - self._make_one( + self._make_client( project=project, pool_size=pool_size, credentials=credentials, @@ -133,17 +139,16 @@ async def test_ctor_super_inits(self): assert 
kwargs["credentials"] == credentials assert kwargs["client_options"] == options_parsed - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor_dict_options(self): - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) from google.api_core.client_options import ClientOptions client_options = {"api_endpoint": "foo.bar:1234"} - with mock.patch.object(BigtableAsyncClient, "__init__") as bigtable_client_init: + with mock.patch.object( + CrossSync.GapicClient, "__init__" + ) as bigtable_client_init: try: - self._make_one(client_options=client_options) + self._make_client(client_options=client_options) except TypeError: pass bigtable_client_init.assert_called_once() @@ -154,17 +159,29 @@ async def test_ctor_dict_options(self): with mock.patch.object( self._get_target_class(), "_start_background_channel_refresh" ) as start_background_refresh: - client = self._make_one(client_options=client_options, use_emulator=False) + client = self._make_client( + client_options=client_options, use_emulator=False + ) start_background_refresh.assert_called_once() await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_veneer_grpc_headers(self): + client_component = "data-async" if CrossSync.is_async else "data" + VENEER_HEADER_REGEX = re.compile( + r"gapic\/[0-9]+\.[\w.-]+ gax\/[0-9]+\.[\w.-]+ gccl\/[0-9]+\.[\w.-]+-" + + client_component + + r" gl-python\/[0-9]+\.[\w.-]+ grpc\/[0-9]+\.[\w.-]+" + ) + # client_info should be populated with headers to # detect as a veneer client - patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") + if CrossSync.is_async: + patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") + else: + patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") with patch as gapic_mock: - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") wrapped_call_list = gapic_mock.call_args_list assert len(wrapped_call_list) > 0 # each wrapped call should have veneer headers @@ -179,33 +196,27 @@ async def test_veneer_grpc_headers(self): ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_channel_pool_creation(self): pool_size = 14 - with mock.patch( - "google.api_core.grpc_helpers_async.create_channel" + with mock.patch.object( + CrossSync.grpc_helpers, "create_channel", CrossSync.Mock() ) as create_channel: - create_channel.return_value = AsyncMock() - client = self._make_one(project="project-id", pool_size=pool_size) + client = self._make_client(project="project-id", pool_size=pool_size) assert create_channel.call_count == pool_size await client.close() # channels should be unique - client = self._make_one(project="project-id", pool_size=pool_size) + client = self._make_client(project="project-id", pool_size=pool_size) pool_list = list(client.transport._grpc_channel._pool) pool_set = set(client.transport._grpc_channel._pool) assert len(pool_list) == len(pool_set) await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_channel_pool_rotation(self): - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel, - ) - pool_size = 7 - - with mock.patch.object(PooledChannel, "next_channel") as next_channel: - client = self._make_one(project="project-id", pool_size=pool_size) + with mock.patch.object(CrossSync.PooledChannel, "next_channel") as next_channel: + client = 
self._make_client(project="project-id", pool_size=pool_size) assert len(client.transport._grpc_channel._pool) == pool_size next_channel.reset_mock() with mock.patch.object( @@ -224,25 +235,30 @@ async def test_channel_pool_rotation(self): unary_unary.reset_mock() await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_channel_pool_replace(self): - with mock.patch.object(asyncio, "sleep"): + import time + + sleep_module = asyncio if CrossSync.is_async else time + with mock.patch.object(sleep_module, "sleep"): pool_size = 7 - client = self._make_one(project="project-id", pool_size=pool_size) + client = self._make_client(project="project-id", pool_size=pool_size) for replace_idx in range(pool_size): start_pool = [ channel for channel in client.transport._grpc_channel._pool ] grace_period = 9 with mock.patch.object( - type(client.transport._grpc_channel._pool[0]), "close" + type(client.transport._grpc_channel._pool[-1]), "close" ) as close: - new_channel = grpc.aio.insecure_channel("localhost:8080") + new_channel = client.transport.create_channel() await client.transport.replace_channel( replace_idx, grace=grace_period, new_channel=new_channel ) - close.assert_called_once_with(grace=grace_period) - close.assert_awaited_once() + close.assert_called_once() + if CrossSync.is_async: + close.assert_called_once_with(grace=grace_period) + close.assert_awaited_once() assert client.transport._grpc_channel._pool[replace_idx] == new_channel for i in range(pool_size): if i != replace_idx: @@ -251,50 +267,59 @@ async def test_channel_pool_replace(self): assert client.transport._grpc_channel._pool[i] != start_pool[i] await client.close() + @CrossSync.drop_method @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test__start_background_channel_refresh_sync(self): # should raise RuntimeError if called in a sync context - client = self._make_one(project="project-id", use_emulator=False) + client = self._make_client(project="project-id", use_emulator=False) with pytest.raises(RuntimeError): client._start_background_channel_refresh() - @pytest.mark.asyncio + @CrossSync.pytest async def test__start_background_channel_refresh_tasks_exist(self): # if tasks exist, should do nothing - client = self._make_one(project="project-id", use_emulator=False) + client = self._make_client(project="project-id", use_emulator=False) assert len(client._channel_refresh_tasks) > 0 with mock.patch.object(asyncio, "create_task") as create_task: client._start_background_channel_refresh() create_task.assert_not_called() await client.close() - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize("pool_size", [1, 3, 7]) async def test__start_background_channel_refresh(self, pool_size): + import concurrent.futures + # should create background tasks for each channel - client = self._make_one( - project="project-id", pool_size=pool_size, use_emulator=False - ) - ping_and_warm = AsyncMock() - client._ping_and_warm_instances = ping_and_warm - client._start_background_channel_refresh() - assert len(client._channel_refresh_tasks) == pool_size - for task in client._channel_refresh_tasks: - assert isinstance(task, asyncio.Task) - await asyncio.sleep(0.1) - assert ping_and_warm.call_count == pool_size - for channel in client.transport._grpc_channel._pool: - ping_and_warm.assert_any_call(channel) + with mock.patch.object( + self._get_target_class(), "_ping_and_warm_instances", CrossSync.Mock() + ) as ping_and_warm: + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False 
+ ) + client._start_background_channel_refresh() + assert len(client._channel_refresh_tasks) == pool_size + for task in client._channel_refresh_tasks: + if CrossSync.is_async: + assert isinstance(task, asyncio.Task) + else: + assert isinstance(task, concurrent.futures.Future) + if CrossSync.is_async: + await asyncio.sleep(0.1) + assert ping_and_warm.call_count == pool_size + for channel in client.transport._grpc_channel._pool: + ping_and_warm.assert_any_call(channel) await client.close() - @pytest.mark.asyncio + @CrossSync.drop_method + @CrossSync.pytest @pytest.mark.skipif( sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" ) async def test__start_background_channel_refresh_tasks_names(self): # if tasks exist, should do nothing pool_size = 3 - client = self._make_one( + client = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False ) for i in range(pool_size): @@ -303,15 +328,22 @@ async def test__start_background_channel_refresh_tasks_names(self): assert "BigtableDataClientAsync channel refresh " in name await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__ping_and_warm_instances(self): """ test ping and warm with mocked asyncio.gather """ client_mock = mock.Mock() - with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: - # simulate gather by returning the same number of items as passed in - gather.side_effect = lambda *args, **kwargs: [None for _ in args] + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) + with mock.patch.object( + CrossSync, "gather_partials", CrossSync.Mock() + ) as gather: + # gather_partials is expected to call the function passed, and return the result + gather.side_effect = lambda partials, **kwargs: [None for _ in partials] channel = mock.Mock() # test with no instances client_mock._active_instances = [] @@ -319,10 +351,8 @@ async def test__ping_and_warm_instances(self): client_mock, channel ) assert len(result) == 0 - gather.assert_called_once() - gather.assert_awaited_once() - assert not gather.call_args.args - assert gather.call_args.kwargs == {"return_exceptions": True} + assert gather.call_args.kwargs["return_exceptions"] is True + assert gather.call_args.kwargs["sync_executor"] == client_mock._executor # test with instances client_mock._active_instances = [ (mock.Mock(), mock.Mock(), mock.Mock()) @@ -334,8 +364,11 @@ async def test__ping_and_warm_instances(self): ) assert len(result) == 4 gather.assert_called_once() - gather.assert_awaited_once() - assert len(gather.call_args.args) == 4 + # expect one partial for each instance + partial_list = gather.call_args.args[0] + assert len(partial_list) == 4 + if CrossSync.is_async: + gather.assert_awaited_once() # check grpc call arguments grpc_call_args = channel.unary_unary().call_args_list for idx, (_, kwargs) in enumerate(grpc_call_args): @@ -355,15 +388,21 @@ async def test__ping_and_warm_instances(self): == f"name={expected_instance}&app_profile_id={expected_app_profile}" ) - @pytest.mark.asyncio + @CrossSync.pytest async def test__ping_and_warm_single_instance(self): """ should be able to call ping and warm with single instance """ client_mock = mock.Mock() - with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: - # simulate gather by returning the same number of items as passed in - gather.side_effect = lambda *args, **kwargs: [None for _ in args] + client_mock._execute_ping_and_warms = ( + lambda *args: 
self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) + with mock.patch.object( + CrossSync, "gather_partials", CrossSync.Mock() + ) as gather: + gather.side_effect = lambda *args, **kwargs: [fn() for fn in args[0]] channel = mock.Mock() # test with large set of instances client_mock._active_instances = [mock.Mock()] * 100 @@ -387,7 +426,7 @@ async def test__ping_and_warm_single_instance(self): metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" ) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "refresh_interval, wait_time, expected_sleep", [ @@ -405,41 +444,46 @@ async def test__manage_channel_first_sleep( # first sleep time should be `refresh_interval` seconds after client init import time - with mock.patch.object(time, "monotonic") as time: - time.return_value = 0 - with mock.patch.object(asyncio, "sleep") as sleep: + with mock.patch.object(time, "monotonic") as monotonic: + monotonic.return_value = 0 + with mock.patch.object(CrossSync, "event_wait") as sleep: sleep.side_effect = asyncio.CancelledError try: - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") client._channel_init_time = -wait_time await client._manage_channel(0, refresh_interval, refresh_interval) except asyncio.CancelledError: pass sleep.assert_called_once() - call_time = sleep.call_args[0][0] + call_time = sleep.call_args[0][1] assert ( abs(call_time - expected_sleep) < 0.1 ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__manage_channel_ping_and_warm(self): """ _manage channel should call ping and warm internally """ import time + import threading client_mock = mock.Mock() + client_mock._is_closed.is_set.return_value = False client_mock._channel_init_time = time.monotonic() channel_list = [mock.Mock(), mock.Mock()] client_mock.transport.channels = channel_list new_channel = mock.Mock() client_mock.transport.grpc_channel._create_channel.return_value = new_channel # should ping an warm all new channels, and old channels if sleeping - with mock.patch.object(asyncio, "sleep"): + sleep_tuple = ( + (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple): # stop process after replace_channel is called client_mock.transport.replace_channel.side_effect = asyncio.CancelledError - ping_and_warm = client_mock._ping_and_warm_instances = AsyncMock() + ping_and_warm = client_mock._ping_and_warm_instances = CrossSync.Mock() # should ping and warm old channel then new if sleep > 0 try: channel_idx = 1 @@ -466,7 +510,7 @@ async def test__manage_channel_ping_and_warm(self): pass ping_and_warm.assert_called_once_with(new_channel) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "refresh_interval, num_cycles, expected_sleep", [ @@ -481,43 +525,59 @@ async def test__manage_channel_sleeps( # make sure that sleeps work as expected import time import random + import threading channel_idx = 1 with mock.patch.object(random, "uniform") as uniform: uniform.side_effect = lambda min_, max_: min_ - with mock.patch.object(time, "time") as time: - time.return_value = 0 - with mock.patch.object(asyncio, "sleep") as sleep: + with mock.patch.object(time, "time") as time_mock: + time_mock.return_value = 0 + sleep_tuple = ( + (asyncio, "sleep") + if CrossSync.is_async + else (threading.Event, "wait") + ) + with 
mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = [None for i in range(num_cycles - 1)] + [ asyncio.CancelledError ] - try: - client = self._make_one(project="project-id") - if refresh_interval is not None: - await client._manage_channel( - channel_idx, refresh_interval, refresh_interval - ) - else: - await client._manage_channel(channel_idx) - except asyncio.CancelledError: - pass + client = self._make_client(project="project-id") + with mock.patch.object(client.transport, "replace_channel"): + try: + if refresh_interval is not None: + await client._manage_channel( + channel_idx, refresh_interval, refresh_interval + ) + else: + await client._manage_channel(channel_idx) + except asyncio.CancelledError: + pass assert sleep.call_count == num_cycles - total_sleep = sum([call[0][0] for call in sleep.call_args_list]) + if CrossSync.is_async: + total_sleep = sum([call[0][0] for call in sleep.call_args_list]) + else: + total_sleep = sum( + [call[1]["timeout"] for call in sleep.call_args_list] + ) assert ( abs(total_sleep - expected_sleep) < 0.1 ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__manage_channel_random(self): import random + import threading - with mock.patch.object(asyncio, "sleep") as sleep: + sleep_tuple = ( + (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: with mock.patch.object(random, "uniform") as uniform: uniform.return_value = 0 try: uniform.side_effect = asyncio.CancelledError - client = self._make_one(project="project-id", pool_size=1) + client = self._make_client(project="project-id", pool_size=1) except asyncio.CancelledError: uniform.side_effect = None uniform.reset_mock() @@ -527,41 +587,48 @@ async def test__manage_channel_random(self): uniform.side_effect = lambda min_, max_: min_ sleep.side_effect = [None, None, asyncio.CancelledError] try: - await client._manage_channel(0, min_val, max_val) + with mock.patch.object(client.transport, "replace_channel"): + await client._manage_channel(0, min_val, max_val) except asyncio.CancelledError: pass - assert uniform.call_count == 2 + assert uniform.call_count == 3 uniform_args = [call[0] for call in uniform.call_args_list] for found_min, found_max in uniform_args: assert found_min == min_val assert found_max == max_val - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) async def test__manage_channel_refresh(self, num_cycles): # make sure that channels are properly refreshed - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - ) - from google.api_core import grpc_helpers_async + import threading expected_grace = 9 expected_refresh = 0.5 channel_idx = 1 - new_channel = grpc.aio.insecure_channel("localhost:8080") + grpc_lib = grpc.aio if CrossSync.is_async else grpc + new_channel = grpc_lib.insecure_channel("localhost:8080") with mock.patch.object( - PooledBigtableGrpcAsyncIOTransport, "replace_channel" + CrossSync.PooledTransport, "replace_channel" ) as replace_channel: - with mock.patch.object(asyncio, "sleep") as sleep: + sleep_tuple = ( + (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: sleep.side_effect = [None for i in range(num_cycles)] + [ asyncio.CancelledError ] with mock.patch.object( - grpc_helpers_async, 
"create_channel" + CrossSync.grpc_helpers, "create_channel" ) as create_channel: create_channel.return_value = new_channel - client = self._make_one(project="project-id", use_emulator=False) + with mock.patch.object( + self._get_target_class(), "_start_background_channel_refresh" + ): + client = self._make_client( + project="project-id", use_emulator=False + ) create_channel.reset_mock() try: await client._manage_channel( @@ -582,7 +649,7 @@ async def test__manage_channel_refresh(self, num_cycles): assert kwargs["new_channel"] == new_channel await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__register_instance(self): """ test instance registration @@ -600,7 +667,7 @@ async def test__register_instance(self): ) mock_channels = [mock.Mock() for i in range(5)] client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = AsyncMock() + client_mock._ping_and_warm_instances = CrossSync.Mock() table_mock = mock.Mock() await self._get_target_class()._register_instance( client_mock, "instance-1", table_mock @@ -653,7 +720,7 @@ async def test__register_instance(self): ] ) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "insert_instances,expected_active,expected_owner_keys", [ @@ -686,7 +753,7 @@ async def test__register_instance_state( ) mock_channels = [mock.Mock() for i in range(5)] client_mock.transport.channels = mock_channels - client_mock._ping_and_warm_instances = AsyncMock() + client_mock._ping_and_warm_instances = CrossSync.Mock() table_mock = mock.Mock() # register instances for instance, table, profile in insert_instances: @@ -712,9 +779,9 @@ async def test__register_instance_state( ] ) - @pytest.mark.asyncio + @CrossSync.pytest async def test__remove_instance_registration(self): - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") table = mock.Mock() await client._register_instance("instance-1", table) await client._register_instance("instance-2", table) @@ -743,16 +810,16 @@ async def test__remove_instance_registration(self): assert len(client._active_instances) == 1 await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test__multiple_table_registration(self): """ registering with multiple tables with the same key should add multiple owners to instance_owners, but only keep one copy of shared key in active_instances """ - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - async with self._make_one(project="project-id") as client: + async with self._make_client(project="project-id") as client: async with client.get_table("instance_1", "table_1") as table_1: instance_1_path = client._gapic_client.instance_path( client.project, "instance_1" @@ -765,12 +832,20 @@ async def test__multiple_table_registration(self): assert id(table_1) in client._instance_owners[instance_1_key] # duplicate table should register in instance_owners under same key async with client.get_table("instance_1", "table_1") as table_2: + assert table_2._register_instance_future is not None + if not CrossSync.is_async: + # give the background task time to run + table_2._register_instance_future.result() assert len(client._instance_owners[instance_1_key]) == 2 assert len(client._active_instances) == 1 assert id(table_1) in client._instance_owners[instance_1_key] assert id(table_2) in client._instance_owners[instance_1_key] # unique table should register in instance_owners and active_instances 
async with client.get_table("instance_1", "table_3") as table_3: + assert table_3._register_instance_future is not None + if not CrossSync.is_async: + # give the background task time to run + table_3._register_instance_future.result() instance_3_path = client._gapic_client.instance_path( client.project, "instance_1" ) @@ -792,17 +867,25 @@ async def test__multiple_table_registration(self): assert instance_1_key not in client._active_instances assert len(client._instance_owners[instance_1_key]) == 0 - @pytest.mark.asyncio + @CrossSync.pytest async def test__multiple_instance_registration(self): """ registering with multiple instance keys should update the key in instance_owners and active_instances """ - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - async with self._make_one(project="project-id") as client: + async with self._make_client(project="project-id") as client: async with client.get_table("instance_1", "table_1") as table_1: + assert table_1._register_instance_future is not None + if not CrossSync.is_async: + # give the background task time to run + table_1._register_instance_future.result() async with client.get_table("instance_2", "table_2") as table_2: + assert table_2._register_instance_future is not None + if not CrossSync.is_async: + # give the background task time to run + table_2._register_instance_future.result() instance_1_path = client._gapic_client.instance_path( client.project, "instance_1" ) @@ -831,12 +914,11 @@ async def test__multiple_instance_registration(self): assert len(client._instance_owners[instance_1_key]) == 0 assert len(client._instance_owners[instance_2_key]) == 0 - @pytest.mark.asyncio + @CrossSync.pytest async def test_get_table(self): - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey - client = self._make_one(project="project-id") + client = self._make_client(project="project-id") assert not client._active_instances expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -846,8 +928,8 @@ async def test_get_table(self): expected_table_id, expected_app_profile_id, ) - await asyncio.sleep(0) - assert isinstance(table, TableAsync) + await CrossSync.yield_to_event_loop() + assert isinstance(table, CrossSync.TestTable._get_target_class()) assert table.table_id == expected_table_id assert ( table.table_name @@ -867,14 +949,14 @@ async def test_get_table(self): assert client._instance_owners[instance_key] == {id(table)} await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_get_table_arg_passthrough(self): """ All arguments passed in get_table should be sent to constructor """ - async with self._make_one(project="project-id") as client: - with mock.patch( - "google.cloud.bigtable.data._async.client.TableAsync.__init__", + async with self._make_client(project="project-id") as client: + with mock.patch.object( + CrossSync.TestTable._get_target_class(), "__init__" ) as mock_constructor: mock_constructor.return_value = None assert not client._active_instances @@ -900,25 +982,26 @@ async def test_get_table_arg_passthrough(self): **expected_kwargs, ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_get_table_context_manager(self): - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.client import 
_WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey expected_table_id = "table-id" expected_instance_id = "instance-id" expected_app_profile_id = "app-profile-id" expected_project_id = "project-id" - with mock.patch.object(TableAsync, "close") as close_mock: - async with self._make_one(project=expected_project_id) as client: + with mock.patch.object( + CrossSync.TestTable._get_target_class(), "close" + ) as close_mock: + async with self._make_client(project=expected_project_id) as client: async with client.get_table( expected_instance_id, expected_table_id, expected_app_profile_id, ) as table: - await asyncio.sleep(0) - assert isinstance(table, TableAsync) + await CrossSync.yield_to_event_loop() + assert isinstance(table, CrossSync.TestTable._get_target_class()) assert table.table_id == expected_table_id assert ( table.table_name @@ -938,16 +1021,16 @@ async def test_get_table_context_manager(self): assert client._instance_owners[instance_key] == {id(table)} assert close_mock.call_count == 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_multiple_pool_sizes(self): # should be able to create multiple clients with different pool sizes without issue pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] for pool_size in pool_sizes: - client = self._make_one( + client = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False ) assert len(client._channel_refresh_tasks) == pool_size - client_duplicate = self._make_one( + client_duplicate = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False ) assert len(client_duplicate._channel_refresh_tasks) == pool_size @@ -955,14 +1038,10 @@ async def test_multiple_pool_sizes(self): await client.close() await client_duplicate.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_close(self): - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - ) - pool_size = 7 - client = self._make_one( + client = self._make_client( project="project-id", pool_size=pool_size, use_emulator=False ) assert len(client._channel_refresh_tasks) == pool_size @@ -970,36 +1049,36 @@ async def test_close(self): for task in client._channel_refresh_tasks: assert not task.done() with mock.patch.object( - PooledBigtableGrpcAsyncIOTransport, "close", AsyncMock() + CrossSync.PooledTransport, "close", CrossSync.Mock() ) as close_mock: await client.close() close_mock.assert_called_once() - close_mock.assert_awaited() + if CrossSync.is_async: + close_mock.assert_awaited() for task in tasks_list: assert task.done() - assert task.cancelled() - assert client._channel_refresh_tasks == [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_close_with_timeout(self): pool_size = 7 expected_timeout = 19 - client = self._make_one(project="project-id", pool_size=pool_size) + client = self._make_client(project="project-id", pool_size=pool_size) tasks = list(client._channel_refresh_tasks) - with mock.patch.object(asyncio, "wait_for", AsyncMock()) as wait_for_mock: + with mock.patch.object(CrossSync, "wait", CrossSync.Mock()) as wait_for_mock: await client.close(timeout=expected_timeout) wait_for_mock.assert_called_once() - wait_for_mock.assert_awaited() + if CrossSync.is_async: + wait_for_mock.assert_awaited() assert wait_for_mock.call_args[1]["timeout"] == expected_timeout client._channel_refresh_tasks = tasks await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_context_manager(self): # context 
manager should close the client cleanly - close_mock = AsyncMock() + close_mock = CrossSync.Mock() true_close = None - async with self._make_one(project="project-id") as client: + async with self._make_client(project="project-id") as client: true_close = client.close() client.close = close_mock for task in client._channel_refresh_tasks: @@ -1008,15 +1087,17 @@ async def test_context_manager(self): assert client._active_instances == set() close_mock.assert_not_called() close_mock.assert_called_once() - close_mock.assert_awaited() + if CrossSync.is_async: + close_mock.assert_awaited() # actually close the client await true_close + @CrossSync.drop_method def test_client_ctor_sync(self): # initializing client in a sync context should raise RuntimeError with pytest.warns(RuntimeWarning) as warnings: - client = _make_client(project="project-id", use_emulator=False) + client = self._make_client(project="project-id", use_emulator=False) expected_warning = [w for w in warnings if "client.py" in w.filename] assert len(expected_warning) == 1 assert ( @@ -1027,11 +1108,22 @@ def test_client_ctor_sync(self): assert client._channel_refresh_tasks == [] +@CrossSync.export_sync( + path="tests.unit.data._sync.test_client.TestTable", add_mapping_for_name="TestTable" +) class TestTableAsync: - @pytest.mark.asyncio + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + + @staticmethod + @CrossSync.convert + def _get_target_class(): + return CrossSync.Table + + @CrossSync.pytest async def test_table_ctor(self): - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -1042,10 +1134,10 @@ async def test_table_ctor(self): expected_read_rows_attempt_timeout = 0.5 expected_mutate_rows_operation_timeout = 2.5 expected_mutate_rows_attempt_timeout = 0.75 - client = _make_client() + client = self._make_client() assert not client._active_instances - table = TableAsync( + table = self._get_target_class()( client, expected_instance_id, expected_table_id, @@ -1057,7 +1149,7 @@ async def test_table_ctor(self): default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, ) - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert table.table_id == expected_table_id assert table.instance_id == expected_instance_id assert table.app_profile_id == expected_app_profile_id @@ -1086,30 +1178,28 @@ async def test_table_ctor(self): == expected_mutate_rows_attempt_timeout ) # ensure task reaches completion - await table._register_instance_task - assert table._register_instance_task.done() - assert not table._register_instance_task.cancelled() - assert table._register_instance_task.exception() is None + await table._register_instance_future + assert table._register_instance_future.done() + assert not table._register_instance_future.cancelled() + assert table._register_instance_future.exception() is None await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_table_ctor_defaults(self): """ should provide default timeout values and app_profile_id """ - from google.cloud.bigtable.data._async.client import TableAsync - expected_table_id = "table-id" expected_instance_id = "instance-id" - client = _make_client() + 
client = self._make_client() assert not client._active_instances - table = TableAsync( + table = self._get_target_class()( client, expected_instance_id, expected_table_id, ) - await asyncio.sleep(0) + await CrossSync.yield_to_event_loop() assert table.table_id == expected_table_id assert table.instance_id == expected_instance_id assert table.app_profile_id is None @@ -1122,14 +1212,12 @@ async def test_table_ctor_defaults(self): assert table.default_mutate_rows_attempt_timeout == 60 await client.close() - @pytest.mark.asyncio + @CrossSync.pytest async def test_table_ctor_invalid_timeout_values(self): """ bad timeout values should raise ValueError """ - from google.cloud.bigtable.data._async.client import TableAsync - - client = _make_client() + client = self._make_client() timeout_pairs = [ ("default_operation_timeout", "default_attempt_timeout"), @@ -1144,68 +1232,67 @@ async def test_table_ctor_invalid_timeout_values(self): ] for operation_timeout, attempt_timeout in timeout_pairs: with pytest.raises(ValueError) as e: - TableAsync(client, "", "", **{attempt_timeout: -1}) + self._get_target_class()(client, "", "", **{attempt_timeout: -1}) assert "attempt_timeout must be greater than 0" in str(e.value) with pytest.raises(ValueError) as e: - TableAsync(client, "", "", **{operation_timeout: -1}) + self._get_target_class()(client, "", "", **{operation_timeout: -1}) assert "operation_timeout must be greater than 0" in str(e.value) await client.close() + @CrossSync.drop_method def test_table_ctor_sync(self): # initializing client in a sync context should raise RuntimeError - from google.cloud.bigtable.data._async.client import TableAsync - client = mock.Mock() with pytest.raises(RuntimeError) as e: TableAsync(client, "instance-id", "table-id") assert e.match("TableAsync must be created within an async event loop context.") - @pytest.mark.asyncio + @CrossSync.pytest # iterate over all retryable rpcs @pytest.mark.parametrize( - "fn_name,fn_args,retry_fn_path,extra_retryables", + "fn_name,fn_args,is_stream,extra_retryables", [ ( "read_rows_stream", (ReadRowsQuery(),), - "google.api_core.retry.retry_target_stream_async", + True, (), ), ( "read_rows", (ReadRowsQuery(),), - "google.api_core.retry.retry_target_stream_async", + True, (), ), ( "read_row", (b"row_key",), - "google.api_core.retry.retry_target_stream_async", + True, (), ), ( "read_rows_sharded", ([ReadRowsQuery()],), - "google.api_core.retry.retry_target_stream_async", + True, (), ), ( "row_exists", (b"row_key",), - "google.api_core.retry.retry_target_stream_async", + True, (), ), - ("sample_row_keys", (), "google.api_core.retry.retry_target_async", ()), + ("sample_row_keys", (), False, ()), ( "mutate_row", (b"row_key", [mock.Mock()]), - "google.api_core.retry.retry_target_async", + False, (), ), ( "bulk_mutate_rows", - ([mutations.RowMutationEntry(b"key", [mock.Mock()])],), - "google.api_core.retry.retry_target_async", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + False, (_MutateRowsIncomplete,), ), ], @@ -1240,17 +1327,26 @@ async def test_customizable_retryable_errors( expected_retryables, fn_name, fn_args, - retry_fn_path, + is_stream, extra_retryables, ): """ Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. 
""" - with mock.patch(retry_fn_path) as retry_fn_mock: - async with _make_client() as client: + retry_fn = "retry_target" + if is_stream: + retry_fn += "_stream" + if CrossSync.is_async: + retry_fn = f"CrossSync.{retry_fn}" + else: + retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" + with mock.patch( + f"google.cloud.bigtable.data._sync.cross_sync.{retry_fn}" + ) as retry_fn_mock: + async with self._make_client() as client: table = client.get_table("instance-id", "table-id") - expected_predicate = lambda a: a in expected_retryables # noqa + expected_predicate = expected_retryables.__contains__ retry_fn_mock.side_effect = RuntimeError("stop early") with mock.patch( "google.api_core.retry.if_exception_type" @@ -1292,18 +1388,19 @@ async def test_customizable_retryable_errors( ], ) @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio + @CrossSync.pytest + @CrossSync.convert async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): """check that all requests attach proper metadata headers""" - from google.cloud.bigtable.data import TableAsync - profile = "profile" if include_app_profile else None - with mock.patch( - f"google.cloud.bigtable_v2.BigtableAsyncClient.{gapic_fn}", mock.AsyncMock() + with mock.patch.object( + CrossSync.GapicClient, gapic_fn, CrossSync.Mock() ) as gapic_mock: gapic_mock.side_effect = RuntimeError("stop early") - async with _make_client() as client: - table = TableAsync(client, "instance-id", "table-id", profile) + async with self._make_client() as client: + table = self._get_target_class()( + client, "instance-id", "table-id", profile + ) try: test_fn = table.__getattribute__(fn_name) maybe_stream = await test_fn(*fn_args) @@ -1325,20 +1422,32 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ assert "app_profile_id=" not in goog_metadata -class TestReadRows: +@CrossSync.export_sync( + path="tests.unit.data._sync.test_client.TestReadRows", + add_mapping_for_name="TestReadRows", +) +class TestReadRowsAsync: """ Tests for table.read_rows and related methods. 
""" - def _make_table(self, *args, **kwargs): - from google.cloud.bigtable.data._async.client import TableAsync + @staticmethod + @CrossSync.convert + def _get_operation_class(): + return CrossSync._ReadRowsOperation + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + + @CrossSync.convert + def _make_table(self, *args, **kwargs): client_mock = mock.Mock() client_mock._register_instance.side_effect = ( - lambda *args, **kwargs: asyncio.sleep(0) + lambda *args, **kwargs: CrossSync.yield_to_event_loop() ) client_mock._remove_instance_registration.side_effect = ( - lambda *args, **kwargs: asyncio.sleep(0) + lambda *args, **kwargs: CrossSync.yield_to_event_loop() ) kwargs["instance_id"] = kwargs.get( "instance_id", args[0] if args else "instance" @@ -1348,7 +1457,7 @@ def _make_table(self, *args, **kwargs): ) client_mock._gapic_client.table_path.return_value = kwargs["table_id"] client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] - return TableAsync(client_mock, *args, **kwargs) + return CrossSync.TestTable._get_target_class()(client_mock, *args, **kwargs) def _make_stats(self): from google.cloud.bigtable_v2.types import RequestStats @@ -1379,6 +1488,7 @@ def _make_chunk(*args, **kwargs): return ReadRowsResponse.CellChunk(*args, **kwargs) @staticmethod + @CrossSync.convert async def _make_gapic_stream( chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0, @@ -1394,27 +1504,34 @@ def __init__(self, chunk_list, sleep_time): def __aiter__(self): return self + def __iter__(self): + return self + async def __anext__(self): self.idx += 1 if len(self.chunk_list) > self.idx: if sleep_time: - await asyncio.sleep(self.sleep_time) + await CrossSync.sleep(self.sleep_time) chunk = self.chunk_list[self.idx] if isinstance(chunk, Exception): raise chunk else: return ReadRowsResponse(chunks=[chunk]) - raise StopAsyncIteration + raise CrossSync.StopIteration + + def __next__(self): + return self.__anext__() def cancel(self): pass return mock_stream(chunk_list, sleep_time) + @CrossSync.convert async def execute_fn(self, table, *args, **kwargs): return await table.read_rows(*args, **kwargs) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows(self): query = ReadRowsQuery() chunks = [ @@ -1431,7 +1548,7 @@ async def test_read_rows(self): assert results[0].row_key == b"test_1" assert results[1].row_key == b"test_2" - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_stream(self): query = ReadRowsQuery() chunks = [ @@ -1450,7 +1567,7 @@ async def test_read_rows_stream(self): assert results[1].row_key == b"test_2" @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_query_matches_request(self, include_app_profile): from google.cloud.bigtable.data import RowRange from google.cloud.bigtable.data.row_filters import PassAllFilter @@ -1477,14 +1594,14 @@ async def test_read_rows_query_matches_request(self, include_app_profile): assert call_request == query_pb @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_timeout(self, operation_timeout): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows query = ReadRowsQuery() chunks = [self._make_chunk(row_key=b"test_1")] read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks, sleep_time=1 + chunks, 
sleep_time=0.15 ) try: await table.read_rows(query, operation_timeout=operation_timeout) @@ -1502,7 +1619,7 @@ async def test_read_rows_timeout(self, operation_timeout): (0.05, 0.24, 5), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_attempt_timeout( self, per_request_t, operation_t, expected_num ): @@ -1565,7 +1682,7 @@ async def test_read_rows_attempt_timeout( core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_retryable_error(self, exc_type): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows @@ -1596,7 +1713,7 @@ async def test_read_rows_retryable_error(self, exc_type): InvalidChunk, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_non_retryable_error(self, exc_type): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows @@ -1610,18 +1727,17 @@ async def test_read_rows_non_retryable_error(self, exc_type): except exc_type as e: assert e == expected_error - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_revise_request(self): """ Ensure that _revise_request is called between retries """ - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable_v2.types import RowSet return_val = RowSet() with mock.patch.object( - _ReadRowsOperationAsync, "_revise_request_rowset" + self._get_operation_class(), "_revise_request_rowset" ) as revise_rowset: revise_rowset.return_value = return_val async with self._make_table() as table: @@ -1645,16 +1761,14 @@ async def test_read_rows_revise_request(self): revised_call = read_rows.call_args_list[1].args[0] assert revised_call.rows == return_val - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_default_timeouts(self): """ Ensure that the default timeouts are set on the read rows operation when not overridden """ - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - operation_timeout = 8 attempt_timeout = 4 - with mock.patch.object(_ReadRowsOperationAsync, "__init__") as mock_op: + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: mock_op.side_effect = RuntimeError("mock error") async with self._make_table( default_read_rows_operation_timeout=operation_timeout, @@ -1668,16 +1782,14 @@ async def test_read_rows_default_timeouts(self): assert kwargs["operation_timeout"] == operation_timeout assert kwargs["attempt_timeout"] == attempt_timeout - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_default_timeout_override(self): """ When timeouts are passed, they overwrite default values """ - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - operation_timeout = 8 attempt_timeout = 4 - with mock.patch.object(_ReadRowsOperationAsync, "__init__") as mock_op: + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: mock_op.side_effect = RuntimeError("mock error") async with self._make_table( default_operation_timeout=99, default_attempt_timeout=97 @@ -1694,10 +1806,10 @@ async def test_read_rows_default_timeout_override(self): assert kwargs["operation_timeout"] == operation_timeout assert kwargs["attempt_timeout"] == attempt_timeout - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_row(self): """Test reading a single row""" - async with _make_client() as client: + async with self._make_client() as client: 
table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1722,10 +1834,10 @@ async def test_read_row(self): assert query.row_ranges == [] assert query.limit == 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_row_w_filter(self): """Test reading a single row with an added filter""" - async with _make_client() as client: + async with self._make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1755,10 +1867,10 @@ async def test_read_row_w_filter(self): assert query.limit == 1 assert query.filter == expected_filter - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_row_no_response(self): """should return None if row does not exist""" - async with _make_client() as client: + async with self._make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1790,10 +1902,10 @@ async def test_read_row_no_response(self): ([object(), object()], True), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_row_exists(self, return_value, expected_result): """Test checking for row existence""" - async with _make_client() as client: + async with self._make_client() as client: table = client.get_table("instance", "table") row_key = b"test_1" with mock.patch.object(table, "read_rows") as read_rows: @@ -1827,32 +1939,35 @@ async def test_row_exists(self, return_value, expected_result): assert query.filter._to_dict() == expected_filter -class TestReadRowsSharded: - @pytest.mark.asyncio +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRowsSharded") +class TestReadRowsShardedAsync: + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + + @CrossSync.pytest async def test_read_rows_sharded_empty_query(self): - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as exc: await table.read_rows_sharded([]) assert "empty sharded_query" in str(exc.value) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_multiple_queries(self): """ Test with multiple queries. 
Should return results from both """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( table.client._gapic_client, "read_rows" ) as read_rows: - read_rows.side_effect = ( - lambda *args, **kwargs: TestReadRows._make_gapic_stream( - [ - TestReadRows._make_chunk(row_key=k) - for k in args[0].rows.row_keys - ] - ) + read_rows.side_effect = lambda *args, **kwargs: CrossSync.TestReadRows._make_gapic_stream( + [ + CrossSync.TestReadRows._make_chunk(row_key=k) + for k in args[0].rows.row_keys + ] ) query_1 = ReadRowsQuery(b"test_1") query_2 = ReadRowsQuery(b"test_2") @@ -1862,19 +1977,19 @@ async def test_read_rows_sharded_multiple_queries(self): assert result[1].row_key == b"test_2" @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_multiple_queries_calls(self, n_queries): """ Each query should trigger a separate read_rows call """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: query_list = [ReadRowsQuery() for _ in range(n_queries)] await table.read_rows_sharded(query_list) assert read_rows.call_count == n_queries - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_errors(self): """ Errors should be exposed as ShardedReadRowsExceptionGroups @@ -1882,7 +1997,7 @@ async def test_read_rows_sharded_errors(self): from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.cloud.bigtable.data.exceptions import FailedQueryShardError - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = RuntimeError("mock error") @@ -1902,7 +2017,7 @@ async def test_read_rows_sharded_errors(self): assert exc.value.exceptions[1].index == 1 assert exc.value.exceptions[1].query == query_2 - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_concurrent(self): """ Ensure sharded requests are concurrent @@ -1913,7 +2028,7 @@ async def mock_call(*args, **kwargs): await asyncio.sleep(0.1) return [mock.Mock()] - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call @@ -1926,14 +2041,14 @@ async def mock_call(*args, **kwargs): # if run in sequence, we would expect this to take 1 second assert call_time < 0.2 - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_concurrency_limit(self): """ Only 10 queries should be processed concurrently. 
Others should be queued Should start a new query as soon as previous finishes """ - from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT + from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT assert _CONCURRENCY_LIMIT == 10 # change this test if this changes num_queries = 15 @@ -1951,7 +2066,7 @@ async def mock_call(*args, **kwargs): starting_timeout = 10 - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call @@ -1975,13 +2090,13 @@ async def mock_call(*args, **kwargs): idx = i + _CONCURRENCY_LIMIT assert rpc_start_list[idx] - (i * increment_time) < eps - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_expirary(self): """ If the operation times out before all shards complete, should raise a ShardedReadRowsExceptionGroup """ - from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT + from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup from google.api_core.exceptions import DeadlineExceeded @@ -2001,7 +2116,7 @@ async def mock_call(*args, **kwargs): await asyncio.sleep(next_item) return [mock.Mock()] - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call @@ -2015,7 +2130,7 @@ async def mock_call(*args, **kwargs): # should keep successful queries assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_rows_sharded_negative_batch_timeout(self): """ try to run with batch that starts after operation timeout @@ -2026,10 +2141,10 @@ async def test_read_rows_sharded_negative_batch_timeout(self): from google.api_core.exceptions import DeadlineExceeded async def mock_call(*args, **kwargs): - await asyncio.sleep(0.05) + await CrossSync.sleep(0.05) return [mock.Mock()] - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object(table, "read_rows") as read_rows: read_rows.side_effect = mock_call @@ -2044,14 +2159,20 @@ async def mock_call(*args, **kwargs): ) -class TestSampleRowKeys: +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestSampleRowKeys") +class TestSampleRowKeysAsync: + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + + @CrossSync.convert async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): from google.cloud.bigtable_v2.types import SampleRowKeysResponse for value in sample_list: yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys(self): """ Test that method returns the expected key samples @@ -2061,10 +2182,10 @@ async def test_sample_row_keys(self): (b"test_2", 100), (b"test_3", 200), ] - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as 
sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream(samples) result = await table.sample_row_keys() @@ -2076,12 +2197,12 @@ async def test_sample_row_keys(self): assert result[1] == samples[1] assert result[2] == samples[2] - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_bad_timeout(self): """ should raise error if timeout is negative """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.sample_row_keys(operation_timeout=-1) @@ -2090,11 +2211,11 @@ async def test_sample_row_keys_bad_timeout(self): await table.sample_row_keys(attempt_timeout=-1) assert "attempt_timeout must be greater than 0" in str(e.value) - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_default_timeout(self): """Should fallback to using table default operation_timeout""" expected_timeout = 99 - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table( "i", "t", @@ -2102,7 +2223,7 @@ async def test_sample_row_keys_default_timeout(self): default_attempt_timeout=expected_timeout, ) as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) result = await table.sample_row_keys() @@ -2111,7 +2232,7 @@ async def test_sample_row_keys_default_timeout(self): assert result == [] assert kwargs["retry"] is None - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_gapic_params(self): """ make sure arguments are propagated to gapic call as expected @@ -2120,12 +2241,12 @@ async def test_sample_row_keys_gapic_params(self): expected_profile = "test1" instance = "instance_name" table_id = "my_table" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table( instance, table_id, app_profile_id=expected_profile ) as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) await table.sample_row_keys(attempt_timeout=expected_timeout) @@ -2145,7 +2266,7 @@ async def test_sample_row_keys_gapic_params(self): core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_sample_row_keys_retryable_errors(self, retryable_exception): """ retryable errors should be retried until timeout @@ -2153,10 +2274,10 @@ async def test_sample_row_keys_retryable_errors(self, retryable_exception): from google.api_core.exceptions import DeadlineExceeded from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.side_effect = retryable_exception("mock") with pytest.raises(DeadlineExceeded) as e: @@ -2177,23 +2298,30 @@ async def test_sample_row_keys_retryable_errors(self, retryable_exception): core_exceptions.Aborted, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def 
test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): """ non-retryable errors should cause a raise """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( - table.client._gapic_client, "sample_row_keys", AsyncMock() + table.client._gapic_client, "sample_row_keys", CrossSync.Mock() ) as sample_row_keys: sample_row_keys.side_effect = non_retryable_exception("mock") with pytest.raises(non_retryable_exception): await table.sample_row_keys() -class TestMutateRow: - @pytest.mark.asyncio +@CrossSync.export_sync( + path="tests.unit.data._sync.test_client.TestMutateRow", +) +class TestMutateRowAsync: + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + + @CrossSync.pytest @pytest.mark.parametrize( "mutation_arg", [ @@ -2214,7 +2342,7 @@ class TestMutateRow: async def test_mutate_row(self, mutation_arg): """Test mutations with no errors""" expected_attempt_timeout = 19 - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2249,12 +2377,12 @@ async def test_mutate_row(self, mutation_arg): core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_retryable_errors(self, retryable_exception): from google.api_core.exceptions import DeadlineExceeded from google.cloud.bigtable.data.exceptions import RetryExceptionGroup - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2277,14 +2405,14 @@ async def test_mutate_row_retryable_errors(self, retryable_exception): core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_non_idempotent_retryable_errors( self, retryable_exception ): """ Non-idempotent mutations should not be retried """ - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2310,9 +2438,9 @@ async def test_mutate_row_non_idempotent_retryable_errors( core_exceptions.Aborted, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_non_retryable_errors(self, non_retryable_exception): - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_row" @@ -2331,14 +2459,14 @@ async def test_mutate_row_non_retryable_errors(self, non_retryable_exception): ) @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_metadata(self, include_app_profile): """request should attach metadata headers""" profile = "profile" if include_app_profile else None - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("i", "t", app_profile_id=profile) as table: with mock.patch.object( - client._gapic_client, "mutate_row", AsyncMock() + 
client._gapic_client, "mutate_row", CrossSync.Mock() ) as read_rows: await table.mutate_row("rk", mock.Mock()) kwargs = read_rows.call_args_list[0].kwargs @@ -2355,16 +2483,24 @@ async def test_mutate_row_metadata(self, include_app_profile): assert "app_profile_id=" not in goog_metadata @pytest.mark.parametrize("mutations", [[], None]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_mutate_row_no_mutations(self, mutations): - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.mutate_row("key", mutations=mutations) assert e.value.args[0] == "No mutations provided" -class TestBulkMutateRows: +@CrossSync.export_sync( + path="tests.unit.data._sync.test_client.TestBulkMutateRows", +) +class TestBulkMutateRowsAsync: + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + + @CrossSync.convert async def _mock_response(self, response_list): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -2389,8 +2525,8 @@ async def generator(): return generator() - @pytest.mark.asyncio - @pytest.mark.asyncio + @CrossSync.pytest + @CrossSync.pytest @pytest.mark.parametrize( "mutation_arg", [ @@ -2413,7 +2549,7 @@ async def generator(): async def test_bulk_mutate_rows(self, mutation_arg): """Test mutations with no errors""" expected_attempt_timeout = 19 - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2434,10 +2570,10 @@ async def test_bulk_mutate_rows(self, mutation_arg): assert kwargs["timeout"] == expected_attempt_timeout assert kwargs["retry"] is None - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_rows_multiple_entries(self): """Test mutations with no errors""" - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2458,7 +2594,7 @@ async def test_bulk_mutate_rows_multiple_entries(self): assert kwargs["entries"][0] == entry_1._to_pb() assert kwargs["entries"][1] == entry_2._to_pb() - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "exception", [ @@ -2478,7 +2614,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_retryable( MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2503,7 +2639,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_retryable( cause.exceptions[-1], core_exceptions.DeadlineExceeded ) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "exception", [ @@ -2524,7 +2660,7 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable( MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2551,7 +2687,7 @@ async def 
test_bulk_mutate_rows_idempotent_mutation_error_non_retryable( core_exceptions.ServiceUnavailable, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_idempotent_retryable_request_errors( self, retryable_exception ): @@ -2564,7 +2700,7 @@ async def test_bulk_mutate_idempotent_retryable_request_errors( MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2585,7 +2721,7 @@ async def test_bulk_mutate_idempotent_retryable_request_errors( assert isinstance(cause, RetryExceptionGroup) assert isinstance(cause.exceptions[0], retryable_exception) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "retryable_exception", [ @@ -2602,7 +2738,7 @@ async def test_bulk_mutate_rows_non_idempotent_retryable_errors( MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2634,7 +2770,7 @@ async def test_bulk_mutate_rows_non_idempotent_retryable_errors( ValueError, ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception): """ If the request fails with a non-retryable error, mutations should not be retried @@ -2644,7 +2780,7 @@ async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_excepti MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2664,7 +2800,7 @@ async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_excepti cause = failed_exception.__cause__ assert isinstance(cause, non_retryable_exception) - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_error_index(self): """ Test partial failure, partial success. 
Errors should be associated with the correct index @@ -2680,7 +2816,7 @@ async def test_bulk_mutate_error_index(self): MutationsExceptionGroup, ) - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "mutate_rows" @@ -2715,14 +2851,14 @@ async def test_bulk_mutate_error_index(self): assert isinstance(cause.exceptions[1], DeadlineExceeded) assert isinstance(cause.exceptions[2], FailedPrecondition) - @pytest.mark.asyncio + @CrossSync.pytest async def test_bulk_mutate_error_recovery(self): """ If an error occurs, then resolves, no exception should be raised """ from google.api_core.exceptions import DeadlineExceeded - async with _make_client(project="project") as client: + async with self._make_client(project="project") as client: table = client.get_table("instance", "table") with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: # fail with a retryable error, then a non-retryable one @@ -2740,14 +2876,19 @@ async def test_bulk_mutate_error_recovery(self): await table.bulk_mutate_rows(entries, operation_timeout=1000) -class TestCheckAndMutateRow: +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestCheckAndMutateRow") +class TestCheckAndMutateRowAsync: + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + @pytest.mark.parametrize("gapic_result", [True, False]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate(self, gapic_result): from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse app_profile = "app_profile_id" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table( "instance", "table", app_profile_id=app_profile ) as table: @@ -2784,10 +2925,10 @@ async def test_check_and_mutate(self, gapic_result): assert kwargs["timeout"] == operation_timeout assert kwargs["retry"] is None - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate_bad_timeout(self): """Should raise error if operation_timeout < 0""" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.check_and_mutate_row( @@ -2799,13 +2940,13 @@ async def test_check_and_mutate_bad_timeout(self): ) assert str(e.value) == "operation_timeout must be greater than 0" - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate_single_mutations(self): """if single mutations are passed, they should be internally wrapped in a list""" from google.cloud.bigtable.data.mutations import SetCell from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "check_and_mutate_row" @@ -2825,7 +2966,7 @@ async def test_check_and_mutate_single_mutations(self): assert kwargs["true_mutations"] == [true_mutation._to_pb()] assert kwargs["false_mutations"] == [false_mutation._to_pb()] - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate_predicate_object(self): """predicate filter should be passed to gapic request""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse @@ -2833,7 +2974,7 @@ 
async def test_check_and_mutate_predicate_object(self): mock_predicate = mock.Mock() predicate_pb = {"predicate": "dict"} mock_predicate._to_pb.return_value = predicate_pb - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "check_and_mutate_row" @@ -2851,7 +2992,7 @@ async def test_check_and_mutate_predicate_object(self): assert mock_predicate._to_pb.call_count == 1 assert kwargs["retry"] is None - @pytest.mark.asyncio + @CrossSync.pytest async def test_check_and_mutate_mutations_parsing(self): """mutations objects should be converted to protos""" from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse @@ -2861,7 +3002,7 @@ async def test_check_and_mutate_mutations_parsing(self): for idx, mutation in enumerate(mutations): mutation._to_pb.return_value = f"fake {idx}" mutations.append(DeleteAllFromRow()) - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "check_and_mutate_row" @@ -2888,7 +3029,12 @@ async def test_check_and_mutate_mutations_parsing(self): ) -class TestReadModifyWriteRow: +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadModifyWriteRow") +class TestReadModifyWriteRowAsync: + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + @pytest.mark.parametrize( "call_rules,expected_rules", [ @@ -2910,12 +3056,12 @@ class TestReadModifyWriteRow: ), ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): """ Test that the gapic call is called with given rules """ - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" @@ -2927,21 +3073,21 @@ async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules assert found_kwargs["retry"] is None @pytest.mark.parametrize("rules", [[], None]) - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_no_rules(self, rules): - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table") as table: with pytest.raises(ValueError) as e: await table.read_modify_write_row("key", rules=rules) assert e.value.args[0] == "rules must contain at least one item" - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_call_defaults(self): instance = "instance1" table_id = "table1" project = "project1" row_key = "row_key1" - async with _make_client(project=project) as client: + async with self._make_client(project=project) as client: async with client.get_table(instance, table_id) as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" @@ -2957,12 +3103,12 @@ async def test_read_modify_write_call_defaults(self): assert kwargs["row_key"] == row_key.encode() assert kwargs["timeout"] > 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_call_overrides(self): row_key = b"row_key1" expected_timeout = 12345 profile_id = "profile1" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table( "instance", "table_id", 
app_profile_id=profile_id ) as table: @@ -2980,10 +3126,10 @@ async def test_read_modify_write_call_overrides(self): assert kwargs["row_key"] == row_key assert kwargs["timeout"] == expected_timeout - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_string_key(self): row_key = "string_row_key1" - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table_id") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" @@ -2993,7 +3139,7 @@ async def test_read_modify_write_string_key(self): kwargs = mock_gapic.call_args_list[0][1] assert kwargs["row_key"] == row_key.encode() - @pytest.mark.asyncio + @CrossSync.pytest async def test_read_modify_write_row_building(self): """ results from gapic call should be used to construct row @@ -3003,7 +3149,7 @@ async def test_read_modify_write_row_building(self): from google.cloud.bigtable_v2.types import Row as RowPB mock_response = ReadModifyWriteRowResponse(row=RowPB()) - async with _make_client() as client: + async with self._make_client() as client: async with client.get_table("instance", "table_id") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index cca7c9824..fcd425273 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -14,33 +14,39 @@ import pytest import asyncio +import time import google.api_core.exceptions as core_exceptions +import google.api_core.retry from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.cloud.bigtable.data import TABLE_DEFAULT +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + # try/except added for compatibility with python < 3.8 try: from unittest import mock - from unittest.mock import AsyncMock except ImportError: # pragma: NO COVER import mock # type: ignore - from mock import AsyncMock # type: ignore - - -def _make_mutation(count=1, size=1): - mutation = mock.Mock() - mutation.size.return_value = size - mutation.mutations = [mock.Mock()] * count - return mutation +@CrossSync.export_sync( + path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl" +) class Test_FlowControl: + @staticmethod + @CrossSync.convert + def _target_class(): + return CrossSync._FlowControl + def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): - from google.cloud.bigtable.data._async.mutations_batcher import ( - _FlowControlAsync, - ) + return self._target_class()(max_mutation_count, max_mutation_bytes) - return _FlowControlAsync(max_mutation_count, max_mutation_bytes) + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation def test_ctor(self): max_mutation_count = 9 @@ -50,7 +56,7 @@ def test_ctor(self): assert instance._max_mutation_bytes == max_mutation_bytes assert instance._in_flight_mutation_count == 0 assert instance._in_flight_mutation_bytes == 0 - assert isinstance(instance._capacity_condition, asyncio.Condition) + assert isinstance(instance._capacity_condition, CrossSync.Condition) def test_ctor_invalid_values(self): """Test that values are positive, and fit within expected limits""" @@ -110,7 +116,7 @@ def test__has_capacity( instance._in_flight_mutation_bytes = existing_size assert 
instance._has_capacity(new_count, new_size) == expected - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "existing_count,existing_size,added_count,added_size,new_count,new_size", [ @@ -138,12 +144,12 @@ async def test_remove_from_flow_value_update( instance = self._make_one() instance._in_flight_mutation_count = existing_count instance._in_flight_mutation_bytes = existing_size - mutation = _make_mutation(added_count, added_size) + mutation = self._make_mutation(added_count, added_size) await instance.remove_from_flow(mutation) assert instance._in_flight_mutation_count == new_count assert instance._in_flight_mutation_bytes == new_size - @pytest.mark.asyncio + @CrossSync.pytest async def test__remove_from_flow_unlock(self): """capacity condition should notify after mutation is complete""" instance = self._make_one(10, 10) @@ -156,36 +162,50 @@ async def task_routine(): lambda: instance._has_capacity(1, 1) ) - task = asyncio.create_task(task_routine()) - await asyncio.sleep(0.05) + if CrossSync.is_async: + # for async class, build task to test flow unlock + task = asyncio.create_task(task_routine()) + + def task_alive(): + return not task.done() + + else: + # this branch will be tested in sync version of this test + import threading + + thread = threading.Thread(target=task_routine) + thread.start() + task_alive = thread.is_alive + await CrossSync.sleep(0.05) # should be blocked due to capacity - assert task.done() is False + assert task_alive() is True # try changing size - mutation = _make_mutation(count=0, size=5) + mutation = self._make_mutation(count=0, size=5) + await instance.remove_from_flow([mutation]) - await asyncio.sleep(0.05) + await CrossSync.sleep(0.05) assert instance._in_flight_mutation_count == 10 assert instance._in_flight_mutation_bytes == 5 - assert task.done() is False + assert task_alive() is True # try changing count instance._in_flight_mutation_bytes = 10 - mutation = _make_mutation(count=5, size=0) + mutation = self._make_mutation(count=5, size=0) await instance.remove_from_flow([mutation]) - await asyncio.sleep(0.05) + await CrossSync.sleep(0.05) assert instance._in_flight_mutation_count == 5 assert instance._in_flight_mutation_bytes == 10 - assert task.done() is False + assert task_alive() is True # try changing both instance._in_flight_mutation_count = 10 - mutation = _make_mutation(count=5, size=5) + mutation = self._make_mutation(count=5, size=5) await instance.remove_from_flow([mutation]) - await asyncio.sleep(0.05) + await CrossSync.sleep(0.05) assert instance._in_flight_mutation_count == 5 assert instance._in_flight_mutation_bytes == 5 # task should be complete - assert task.done() is True + assert task_alive() is False - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "mutations,count_cap,size_cap,expected_results", [ @@ -210,7 +230,7 @@ async def test_add_to_flow(self, mutations, count_cap, size_cap, expected_result """ Test batching with various flow control settings """ - mutation_objs = [_make_mutation(count=m[0], size=m[1]) for m in mutations] + mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] instance = self._make_one(count_cap, size_cap) i = 0 async for batch in instance.add_to_flow(mutation_objs): @@ -226,7 +246,7 @@ async def test_add_to_flow(self, mutations, count_cap, size_cap, expected_result i += 1 assert i == len(expected_results) - @pytest.mark.asyncio + @CrossSync.pytest @pytest.mark.parametrize( "mutations,max_limit,expected_results", [ @@ -242,11 +262,12 @@ async def 
test_add_to_flow_max_mutation_limits( Test flow control running up against the max API limit Should submit request early, even if the flow control has room for more """ - with mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT", - max_limit, - ): - mutation_objs = [_make_mutation(count=m[0], size=m[1]) for m in mutations] + subpath = "_async" if CrossSync.is_async else "_sync" + path = f"google.cloud.bigtable.data.{subpath}.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT" + with mock.patch(path, max_limit): + mutation_objs = [ + self._make_mutation(count=m[0], size=m[1]) for m in mutations + ] # flow control has no limits except API restrictions instance = self._make_one(float("inf"), float("inf")) i = 0 @@ -263,14 +284,14 @@ async def test_add_to_flow_max_mutation_limits( i += 1 assert i == len(expected_results) - @pytest.mark.asyncio + @CrossSync.pytest async def test_add_to_flow_oversize(self): """ mutations over the flow control limits should still be accepted """ instance = self._make_one(2, 3) - large_size_mutation = _make_mutation(count=1, size=10) - large_count_mutation = _make_mutation(count=10, size=1) + large_size_mutation = self._make_mutation(count=1, size=10) + large_count_mutation = self._make_mutation(count=10, size=1) results = [out async for out in instance.add_to_flow([large_size_mutation])] assert len(results) == 1 await instance.remove_from_flow(results[0]) @@ -280,13 +301,13 @@ async def test_add_to_flow_oversize(self): assert len(count_results) == 1 +@CrossSync.export_sync( + path="tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" +) class TestMutationsBatcherAsync: + @CrossSync.convert def _get_target_class(self): - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) - - return MutationsBatcherAsync + return CrossSync.MutationsBatcher def _make_one(self, table=None, **kwargs): from google.api_core.exceptions import DeadlineExceeded @@ -303,132 +324,140 @@ def _make_one(self, table=None, **kwargs): return self._get_target_class()(table, **kwargs) - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer" - ) - @pytest.mark.asyncio - async def test_ctor_defaults(self, flush_timer_mock): - flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = [Exception] - async with self._make_one(table) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._max_mutation_count == 100000 - assert instance._flow_control._max_mutation_bytes == 104857600 - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert instance._entries_processed_since_last_raise == 0 - assert ( - instance._operation_timeout - == table.default_mutate_rows_operation_timeout - ) - assert ( - instance._attempt_timeout == table.default_mutate_rows_attempt_timeout - ) - assert ( - instance._retryable_errors == table.default_mutate_rows_retryable_errors - ) - await 
asyncio.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] == 5 - assert isinstance(instance._flush_timer, asyncio.Future) + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer", - ) - @pytest.mark.asyncio - async def test_ctor_explicit(self, flush_timer_mock): + @CrossSync.pytest + async def test_ctor_defaults(self): + with mock.patch.object( + self._get_target_class(), "_timer_routine", return_value=CrossSync.Future() + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = [Exception] + async with self._make_one(table) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._max_mutation_count == 100000 + assert instance._flow_control._max_mutation_bytes == 104857600 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert ( + instance._operation_timeout + == table.default_mutate_rows_operation_timeout + ) + assert ( + instance._attempt_timeout + == table.default_mutate_rows_attempt_timeout + ) + assert ( + instance._retryable_errors + == table.default_mutate_rows_retryable_errors + ) + await CrossSync.yield_to_event_loop() + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == 5 + assert isinstance(instance._flush_timer, CrossSync.Future) + + @CrossSync.pytest + async def test_ctor_explicit(self): """Test with explicit parameters""" - flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) - table = mock.Mock() - flush_interval = 20 - flush_limit_count = 17 - flush_limit_bytes = 19 - flow_control_max_mutation_count = 1001 - flow_control_max_bytes = 12 - operation_timeout = 11 - attempt_timeout = 2 - retryable_errors = [Exception] - async with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - flow_control_max_mutation_count=flow_control_max_mutation_count, - flow_control_max_bytes=flow_control_max_bytes, - batch_operation_timeout=operation_timeout, - batch_attempt_timeout=attempt_timeout, - batch_retryable_errors=retryable_errors, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._flush_jobs == set() - assert len(instance._staged_entries) == 0 - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert ( - instance._flow_control._max_mutation_count - == flow_control_max_mutation_count - ) - assert instance._flow_control._max_mutation_bytes == flow_control_max_bytes - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - 
assert instance._entries_processed_since_last_raise == 0 - assert instance._operation_timeout == operation_timeout - assert instance._attempt_timeout == attempt_timeout - assert instance._retryable_errors == retryable_errors - await asyncio.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] == flush_interval - assert isinstance(instance._flush_timer, asyncio.Future) - - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._start_flush_timer" - ) - @pytest.mark.asyncio - async def test_ctor_no_flush_limits(self, flush_timer_mock): + with mock.patch.object( + self._get_target_class(), "_timer_routine", return_value=CrossSync.Future() + ) as flush_timer_mock: + table = mock.Mock() + flush_interval = 20 + flush_limit_count = 17 + flush_limit_bytes = 19 + flow_control_max_mutation_count = 1001 + flow_control_max_bytes = 12 + operation_timeout = 11 + attempt_timeout = 2 + retryable_errors = [Exception] + async with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=operation_timeout, + batch_attempt_timeout=attempt_timeout, + batch_retryable_errors=retryable_errors, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert ( + instance._flow_control._max_mutation_count + == flow_control_max_mutation_count + ) + assert ( + instance._flow_control._max_mutation_bytes == flow_control_max_bytes + ) + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert instance._operation_timeout == operation_timeout + assert instance._attempt_timeout == attempt_timeout + assert instance._retryable_errors == retryable_errors + await CrossSync.yield_to_event_loop() + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == flush_interval + assert isinstance(instance._flush_timer, CrossSync.Future) + + @CrossSync.pytest + async def test_ctor_no_flush_limits(self): """Test with None for flush limits""" - flush_timer_mock.return_value = asyncio.create_task(asyncio.sleep(0)) - table = mock.Mock() - table.default_mutate_rows_operation_timeout = 10 - table.default_mutate_rows_attempt_timeout = 8 - table.default_mutate_rows_retryable_errors = () - flush_interval = None - flush_limit_count = None - flush_limit_bytes = None - async with self._make_one( - table, - flush_interval=flush_interval, - flush_limit_mutation_count=flush_limit_count, - flush_limit_bytes=flush_limit_bytes, - ) as instance: - assert instance._table == table - assert instance.closed is False - assert instance._staged_entries == [] - assert len(instance._oldest_exceptions) == 0 - assert len(instance._newest_exceptions) == 0 - assert instance._exception_list_limit == 10 - assert instance._exceptions_since_last_raise == 0 - assert instance._flow_control._in_flight_mutation_count == 0 - assert instance._flow_control._in_flight_mutation_bytes == 0 - assert 
instance._entries_processed_since_last_raise == 0 - await asyncio.sleep(0) - assert flush_timer_mock.call_count == 1 - assert flush_timer_mock.call_args[0][0] is None - assert isinstance(instance._flush_timer, asyncio.Future) + with mock.patch.object( + self._get_target_class(), "_timer_routine", return_value=CrossSync.Future() + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = () + flush_interval = None + flush_limit_count = None + flush_limit_bytes = None + async with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._staged_entries == [] + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + await CrossSync.yield_to_event_loop() + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] is None + assert isinstance(instance._flush_timer, CrossSync.Future) - @pytest.mark.asyncio + @CrossSync.pytest async def test_ctor_invalid_values(self): """Test that timeout values are positive, and fit within expected limits""" with pytest.raises(ValueError) as e: @@ -438,24 +467,21 @@ async def test_ctor_invalid_values(self): self._make_one(batch_attempt_timeout=-1) assert "attempt_timeout must be greater than 0" in str(e.value) + @CrossSync.convert def test_default_argument_consistency(self): """ We supply default arguments in MutationsBatcherAsync.__init__, and in table.mutations_batcher. 
Make sure any changes to defaults are applied to both places """ - from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) import inspect get_batcher_signature = dict( - inspect.signature(TableAsync.mutations_batcher).parameters + inspect.signature(CrossSync.Table.mutations_batcher).parameters ) get_batcher_signature.pop("self") batcher_init_signature = dict( - inspect.signature(MutationsBatcherAsync).parameters + inspect.signature(self._get_target_class()).parameters ) batcher_init_signature.pop("table") # both should have same number of arguments @@ -470,97 +496,96 @@ def test_default_argument_consistency(self): == batcher_init_signature[arg_name].default ) - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) - @pytest.mark.asyncio - async def test__start_flush_timer_w_None(self, flush_mock): - """Empty timer should return immediately""" - async with self._make_one() as instance: - with mock.patch("asyncio.sleep") as sleep_mock: - await instance._start_flush_timer(None) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 + @CrossSync.pytest + @pytest.mark.parametrize("input_val", [None, 0, -1]) + async def test__start_flush_timer_w_empty_input(self, input_val): + """Empty/invalid timer should return immediately""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + # mock different method depending on sync vs async + async with self._make_one() as instance: + if CrossSync.is_async: + sleep_obj, sleep_method = asyncio, "wait_for" + else: + sleep_obj, sleep_method = instance._closed, "wait" + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + result = await instance._timer_routine(input_val) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + assert result is None - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) - @pytest.mark.asyncio - async def test__start_flush_timer_call_when_closed(self, flush_mock): + @CrossSync.pytest + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + async def test__start_flush_timer_call_when_closed( + self, + ): """closed batcher's timer should return immediately""" - async with self._make_one() as instance: - await instance.close() - flush_mock.reset_mock() - with mock.patch("asyncio.sleep") as sleep_mock: - await instance._start_flush_timer(1) - assert sleep_mock.call_count == 0 - assert flush_mock.call_count == 0 + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + async with self._make_one() as instance: + await instance.close() + flush_mock.reset_mock() + # mock different method depending on sync vs async + if CrossSync.is_async: + sleep_obj, sleep_method = asyncio, "wait_for" + else: + sleep_obj, sleep_method = instance._closed, "wait" + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + await instance._timer_routine(10) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) - @pytest.mark.asyncio - async def test__flush_timer(self, flush_mock): + @CrossSync.pytest + @pytest.mark.parametrize("num_staged", [0, 1, 10]) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + async def test__flush_timer(self, num_staged): """Timer should continue to 
call _schedule_flush in a loop""" - expected_sleep = 12 - async with self._make_one(flush_interval=expected_sleep) as instance: - instance._staged_entries = [mock.Mock()] - loop_num = 3 - with mock.patch("asyncio.sleep") as sleep_mock: - sleep_mock.side_effect = [None] * loop_num + [asyncio.CancelledError()] - try: - await instance._flush_timer - except asyncio.CancelledError: - pass - assert sleep_mock.call_count == loop_num + 1 - sleep_mock.assert_called_with(expected_sleep) - assert flush_mock.call_count == loop_num - - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) - @pytest.mark.asyncio - async def test__flush_timer_no_mutations(self, flush_mock): - """Timer should not flush if no new mutations have been staged""" - expected_sleep = 12 - async with self._make_one(flush_interval=expected_sleep) as instance: - loop_num = 3 - with mock.patch("asyncio.sleep") as sleep_mock: - sleep_mock.side_effect = [None] * loop_num + [asyncio.CancelledError()] - try: - await instance._flush_timer - except asyncio.CancelledError: - pass - assert sleep_mock.call_count == loop_num + 1 - sleep_mock.assert_called_with(expected_sleep) - assert flush_mock.call_count == 0 + from google.cloud.bigtable.data._sync.cross_sync import CrossSync - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher.MutationsBatcherAsync._schedule_flush" - ) - @pytest.mark.asyncio - async def test__flush_timer_close(self, flush_mock): + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + expected_sleep = 12 + async with self._make_one(flush_interval=expected_sleep) as instance: + loop_num = 3 + instance._staged_entries = [mock.Mock()] * num_staged + with mock.patch.object(CrossSync, "event_wait") as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] + with pytest.raises(TabError): + await self._get_target_class()._timer_routine( + instance, expected_sleep + ) + if CrossSync.is_async: + # replace with np-op so there are no issues on close + instance._flush_timer = CrossSync.Future() + assert sleep_mock.call_count == loop_num + 1 + sleep_kwargs = sleep_mock.call_args[1] + assert sleep_kwargs["timeout"] == expected_sleep + assert flush_mock.call_count == (0 if num_staged == 0 else loop_num) + + @CrossSync.pytest + async def test__flush_timer_close(self): """Timer should continue terminate after close""" - async with self._make_one() as instance: - with mock.patch("asyncio.sleep"): + with mock.patch.object(self._get_target_class(), "_schedule_flush"): + async with self._make_one() as instance: # let task run in background - await asyncio.sleep(0.5) assert instance._flush_timer.done() is False # close the batcher await instance.close() - await asyncio.sleep(0.1) # task should be complete assert instance._flush_timer.done() is True - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_closed(self): """Should raise exception""" + instance = self._make_one() + await instance.close() with pytest.raises(RuntimeError): - instance = self._make_one() - await instance.close() await instance.append(mock.Mock()) - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_wrong_mutation(self): """ Mutation objects should raise an exception. 
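The `_timer_routine` and flush-timer tests above patch `CrossSync.event_wait` and treat `instance._flush_timer` as a `CrossSync.Future`, but the `cross_sync` shim module itself is not part of this diff. A minimal sketch of the async side of the helpers these tests assume could look like the following; the class name, method names, and signatures are illustrative assumptions rather than the actual `CrossSync` API.

import asyncio


class CrossSyncSketch:
    # Hypothetical shim: exposes asyncio primitives under names that the
    # sync generator can rewrite for the threaded surface.
    is_async = True
    Future = asyncio.Future

    @staticmethod
    async def sleep(seconds):
        await asyncio.sleep(seconds)

    @staticmethod
    async def yield_to_event_loop():
        # give other scheduled tasks a chance to run, as the ctor tests do
        # before asserting on flush_timer_mock
        await asyncio.sleep(0)

    @staticmethod
    async def event_wait(event, timeout=None):
        # wait until the event is set or the timeout elapses; the timer tests
        # patch this call to drive the flush loop deterministically
        try:
            await asyncio.wait_for(event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            pass

The sync branches visible in these tests fall back to `threading.Thread` and `instance._closed.wait`, so the generated sync counterpart of such a shim would presumably map the same names onto `threading` and `concurrent.futures` primitives.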
@@ -574,13 +599,13 @@ async def test_append_wrong_mutation(self): await instance.append(DeleteAllFromRow()) assert str(e.value) == expected_error - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_outside_flow_limits(self): """entries larger than mutation limits are still processed""" async with self._make_one( flow_control_max_mutation_count=1, flow_control_max_bytes=1 ) as instance: - oversized_entry = _make_mutation(count=0, size=2) + oversized_entry = self._make_mutation(count=0, size=2) await instance.append(oversized_entry) assert instance._staged_entries == [oversized_entry] assert instance._staged_count == 0 @@ -589,25 +614,21 @@ async def test_append_outside_flow_limits(self): async with self._make_one( flow_control_max_mutation_count=1, flow_control_max_bytes=1 ) as instance: - overcount_entry = _make_mutation(count=2, size=0) + overcount_entry = self._make_mutation(count=2, size=0) await instance.append(overcount_entry) assert instance._staged_entries == [overcount_entry] assert instance._staged_count == 2 assert instance._staged_bytes == 0 instance._staged_entries = [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_flush_runs_after_limit_hit(self): """ If the user appends a bunch of entries above the flush limits back-to-back, it should still flush in a single task """ - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) - with mock.patch.object( - MutationsBatcherAsync, "_execute_mutate_rows" + self._get_target_class(), "_execute_mutate_rows" ) as op_mock: async with self._make_one(flush_limit_bytes=100) as instance: # mock network calls @@ -616,13 +637,13 @@ async def mock_call(*args, **kwargs): op_mock.side_effect = mock_call # append a mutation just under the size limit - await instance.append(_make_mutation(size=99)) + await instance.append(self._make_mutation(size=99)) # append a bunch of entries back-to-back in a loop num_entries = 10 for _ in range(num_entries): - await instance.append(_make_mutation(size=1)) + await instance.append(self._make_mutation(size=1)) # let any flush jobs finish - await asyncio.gather(*instance._flush_jobs) + await instance._wait_for_batch_results(*instance._flush_jobs) # should have only flushed once, with large mutation and first mutation in loop assert op_mock.call_count == 1 sent_batch = op_mock.call_args[0][0] @@ -642,7 +663,8 @@ async def mock_call(*args, **kwargs): (1, 1, 0, 0, False), ], ) - @pytest.mark.asyncio + @CrossSync.pytest + @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_append( self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush ): @@ -653,7 +675,7 @@ async def test_append( assert instance._staged_count == 0 assert instance._staged_bytes == 0 assert instance._staged_entries == [] - mutation = _make_mutation(count=mutation_count, size=mutation_bytes) + mutation = self._make_mutation(count=mutation_count, size=mutation_bytes) with mock.patch.object(instance, "_schedule_flush") as flush_mock: await instance.append(mutation) assert flush_mock.call_count == bool(expect_flush) @@ -662,7 +684,7 @@ async def test_append( assert instance._staged_entries == [mutation] instance._staged_entries = [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_append_multiple_sequentially(self): """Append multiple mutations""" async with self._make_one( @@ -671,7 +693,7 @@ async def test_append_multiple_sequentially(self): assert instance._staged_count == 0 assert instance._staged_bytes == 0 assert 
instance._staged_entries == [] - mutation = _make_mutation(count=2, size=3) + mutation = self._make_mutation(count=2, size=3) with mock.patch.object(instance, "_schedule_flush") as flush_mock: await instance.append(mutation) assert flush_mock.call_count == 0 @@ -690,7 +712,7 @@ async def test_append_multiple_sequentially(self): assert len(instance._staged_entries) == 3 instance._staged_entries = [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_flush_flow_control_concurrent_requests(self): """ requests should happen in parallel if flow control breaks up single flush into batches @@ -698,14 +720,14 @@ async def test_flush_flow_control_concurrent_requests(self): import time num_calls = 10 - fake_mutations = [_make_mutation(count=1) for _ in range(num_calls)] + fake_mutations = [self._make_mutation(count=1) for _ in range(num_calls)] async with self._make_one(flow_control_max_mutation_count=1) as instance: with mock.patch.object( - instance, "_execute_mutate_rows", AsyncMock() + instance, "_execute_mutate_rows", CrossSync.Mock() ) as op_mock: # mock network calls async def mock_call(*args, **kwargs): - await asyncio.sleep(0.1) + await CrossSync.sleep(0.1) return [] op_mock.side_effect = mock_call @@ -713,15 +735,15 @@ async def mock_call(*args, **kwargs): # flush one large batch, that will be broken up into smaller batches instance._staged_entries = fake_mutations instance._schedule_flush() - await asyncio.sleep(0.01) + await CrossSync.sleep(0.01) # make room for new mutations for i in range(num_calls): await instance._flow_control.remove_from_flow( - [_make_mutation(count=1)] + [self._make_mutation(count=1)] ) - await asyncio.sleep(0.01) + await CrossSync.sleep(0.01) # allow flushes to complete - await asyncio.gather(*instance._flush_jobs) + await instance._wait_for_batch_results(*instance._flush_jobs) duration = time.monotonic() - start_time assert len(instance._oldest_exceptions) == 0 assert len(instance._newest_exceptions) == 0 @@ -729,7 +751,7 @@ async def mock_call(*args, **kwargs): assert duration < 0.5 assert op_mock.call_count == num_calls - @pytest.mark.asyncio + @CrossSync.pytest async def test_schedule_flush_no_mutations(self): """schedule flush should return None if no staged mutations""" async with self._make_one() as instance: @@ -738,11 +760,15 @@ async def test_schedule_flush_no_mutations(self): assert instance._schedule_flush() is None assert flush_mock.call_count == 0 - @pytest.mark.asyncio + @CrossSync.pytest + @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_schedule_flush_with_mutations(self): """if new mutations exist, should add a new flush task to _flush_jobs""" async with self._make_one() as instance: with mock.patch.object(instance, "_flush_internal") as flush_mock: + if not CrossSync.is_async: + # simulate operation + flush_mock.side_effect = lambda x: time.sleep(0.1) for i in range(1, 4): mutation = mock.Mock() instance._staged_entries = [mutation] @@ -753,9 +779,10 @@ async def test_schedule_flush_with_mutations(self): assert instance._staged_entries == [] assert instance._staged_count == 0 assert instance._staged_bytes == 0 - assert flush_mock.call_count == i + assert flush_mock.call_count == 1 + flush_mock.reset_mock() - @pytest.mark.asyncio + @CrossSync.pytest async def test__flush_internal(self): """ _flush_internal should: @@ -775,7 +802,7 @@ async def gen(x): yield x flow_mock.side_effect = lambda x: gen(x) - mutations = [_make_mutation(count=1, size=1)] * num_entries + mutations = [self._make_mutation(count=1, size=1)] * 
num_entries await instance._flush_internal(mutations) assert instance._entries_processed_since_last_raise == num_entries assert execute_mock.call_count == 1 @@ -783,20 +810,28 @@ async def gen(x): instance._oldest_exceptions.clear() instance._newest_exceptions.clear() - @pytest.mark.asyncio + @CrossSync.pytest async def test_flush_clears_job_list(self): """ a job should be added to _flush_jobs when _schedule_flush is called, and removed when it completes """ async with self._make_one() as instance: - with mock.patch.object(instance, "_flush_internal", AsyncMock()): - mutations = [_make_mutation(count=1, size=1)] + with mock.patch.object( + instance, "_flush_internal", CrossSync.Mock() + ) as flush_mock: + if not CrossSync.is_async: + # simulate operation + flush_mock.side_effect = lambda x: time.sleep(0.1) + mutations = [self._make_mutation(count=1, size=1)] instance._staged_entries = mutations assert instance._flush_jobs == set() new_job = instance._schedule_flush() assert instance._flush_jobs == {new_job} - await new_job + if CrossSync.is_async: + await new_job + else: + new_job.result() assert instance._flush_jobs == set() @pytest.mark.parametrize( @@ -811,7 +846,7 @@ async def test_flush_clears_job_list(self): (10, 20, 20), # should cap at 20 ], ) - @pytest.mark.asyncio + @CrossSync.pytest async def test__flush_internal_with_errors( self, num_starting, num_new_errors, expected_total_errors ): @@ -836,7 +871,7 @@ async def gen(x): yield x flow_mock.side_effect = lambda x: gen(x) - mutations = [_make_mutation(count=1, size=1)] * num_entries + mutations = [self._make_mutation(count=1, size=1)] * num_entries await instance._flush_internal(mutations) assert instance._entries_processed_since_last_raise == num_entries assert execute_mock.call_count == 1 @@ -853,6 +888,7 @@ async def gen(x): instance._oldest_exceptions.clear() instance._newest_exceptions.clear() + @CrossSync.convert async def _mock_gapic_return(self, num=5): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -866,11 +902,11 @@ async def gen(num): return gen(num) - @pytest.mark.asyncio + @CrossSync.pytest async def test_timer_flush_end_to_end(self): """Flush should automatically trigger after flush_interval""" - num_nutations = 10 - mutations = [_make_mutation(count=2, size=2)] * num_nutations + num_mutations = 10 + mutations = [self._make_mutation(count=2, size=2)] * num_mutations async with self._make_one(flush_interval=0.05) as instance: instance._table.default_operation_timeout = 10 @@ -879,69 +915,65 @@ async def test_timer_flush_end_to_end(self): instance._table.client._gapic_client, "mutate_rows" ) as gapic_mock: gapic_mock.side_effect = ( - lambda *args, **kwargs: self._mock_gapic_return(num_nutations) + lambda *args, **kwargs: self._mock_gapic_return(num_mutations) ) for m in mutations: await instance.append(m) assert instance._entries_processed_since_last_raise == 0 # let flush trigger due to timer - await asyncio.sleep(0.1) - assert instance._entries_processed_since_last_raise == num_nutations - - @pytest.mark.asyncio - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync", - ) - async def test__execute_mutate_rows(self, mutate_rows): - mutate_rows.return_value = AsyncMock() - start_operation = mutate_rows().start - table = mock.Mock() - table.table_name = "test-table" - table.app_profile_id = "test-app-profile" - table.default_mutate_rows_operation_timeout = 17 - table.default_mutate_rows_attempt_timeout = 13 - 
        table.default_mutate_rows_retryable_errors = ()
-        async with self._make_one(table) as instance:
-            batch = [_make_mutation()]
-            result = await instance._execute_mutate_rows(batch)
-            assert start_operation.call_count == 1
-            args, kwargs = mutate_rows.call_args
-            assert args[0] == table.client._gapic_client
-            assert args[1] == table
-            assert args[2] == batch
-            kwargs["operation_timeout"] == 17
-            kwargs["attempt_timeout"] == 13
-            assert result == []
-
-    @pytest.mark.asyncio
-    @mock.patch(
-        "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync.start"
-    )
-    async def test__execute_mutate_rows_returns_errors(self, mutate_rows):
+                await CrossSync.sleep(0.1)
+                assert instance._entries_processed_since_last_raise == num_mutations
+
+    @CrossSync.pytest
+    async def test__execute_mutate_rows(self):
+        with mock.patch.object(CrossSync, "_MutateRowsOperation") as mutate_rows:
+            mutate_rows.return_value = CrossSync.Mock()
+            start_operation = mutate_rows().start
+            table = mock.Mock()
+            table.table_name = "test-table"
+            table.app_profile_id = "test-app-profile"
+            table.default_mutate_rows_operation_timeout = 17
+            table.default_mutate_rows_attempt_timeout = 13
+            table.default_mutate_rows_retryable_errors = ()
+            async with self._make_one(table) as instance:
+                batch = [self._make_mutation()]
+                result = await instance._execute_mutate_rows(batch)
+                assert start_operation.call_count == 1
+                args, kwargs = mutate_rows.call_args
+                assert args[0] == table.client._gapic_client
+                assert args[1] == table
+                assert args[2] == batch
+                kwargs["operation_timeout"] == 17
+                kwargs["attempt_timeout"] == 13
+                assert result == []
+
+    @CrossSync.pytest
+    async def test__execute_mutate_rows_returns_errors(self):
         """Errors from operation should be returned as list"""
         from google.cloud.bigtable.data.exceptions import (
             MutationsExceptionGroup,
             FailedMutationEntryError,
         )
-        err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error"))
-        err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error"))
-        mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10)
-        table = mock.Mock()
-        table.default_mutate_rows_operation_timeout = 17
-        table.default_mutate_rows_attempt_timeout = 13
-        table.default_mutate_rows_retryable_errors = ()
-        async with self._make_one(table) as instance:
-            batch = [_make_mutation()]
-            result = await instance._execute_mutate_rows(batch)
-            assert len(result) == 2
-            assert result[0] == err1
-            assert result[1] == err2
-            # indices should be set to None
-            assert result[0].index is None
-            assert result[1].index is None
-
-    @pytest.mark.asyncio
+        with mock.patch.object(CrossSync._MutateRowsOperation, "start") as mutate_rows:
+            err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error"))
+            err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error"))
+            mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10)
+            table = mock.Mock()
+            table.default_mutate_rows_operation_timeout = 17
+            table.default_mutate_rows_attempt_timeout = 13
+            table.default_mutate_rows_retryable_errors = ()
+            async with self._make_one(table) as instance:
+                batch = [self._make_mutation()]
+                result = await instance._execute_mutate_rows(batch)
+                assert len(result) == 2
+                assert result[0] == err1
+                assert result[1] == err2
+                # indices should be set to None
+                assert result[0].index is None
+                assert result[1].index is None
+
+    @CrossSync.pytest
     async def test__raise_exceptions(self):
         """Raise exceptions and reset error state"""
         from google.cloud.bigtable.data
import exceptions @@ -961,13 +993,19 @@ async def test__raise_exceptions(self): # try calling again instance._raise_exceptions() - @pytest.mark.asyncio + @CrossSync.pytest + @CrossSync.convert( + sync_name="test___enter__", replace_symbols={"__aenter__": "__enter__"} + ) async def test___aenter__(self): """Should return self""" async with self._make_one() as instance: assert await instance.__aenter__() == instance - @pytest.mark.asyncio + @CrossSync.pytest + @CrossSync.convert( + sync_name="test___exit__", replace_symbols={"__aexit__": "__exit__"} + ) async def test___aexit__(self): """aexit should call close""" async with self._make_one() as instance: @@ -975,7 +1013,7 @@ async def test___aexit__(self): await instance.__aexit__(None, None, None) assert close_mock.call_count == 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_close(self): """Should clean up all resources""" async with self._make_one() as instance: @@ -988,7 +1026,7 @@ async def test_close(self): assert flush_mock.call_count == 1 assert raise_mock.call_count == 1 - @pytest.mark.asyncio + @CrossSync.pytest async def test_close_w_exceptions(self): """Raise exceptions on close""" from google.cloud.bigtable.data import exceptions @@ -1007,7 +1045,7 @@ async def test_close_w_exceptions(self): # clear out exceptions instance._oldest_exceptions, instance._newest_exceptions = ([], []) - @pytest.mark.asyncio + @CrossSync.pytest async def test__on_exit(self, recwarn): """Should raise warnings if unflushed mutations exist""" async with self._make_one() as instance: @@ -1023,13 +1061,13 @@ async def test__on_exit(self, recwarn): assert "unflushed mutations" in str(w[0].message).lower() assert str(num_left) in str(w[0].message) # calling while closed is noop - instance.closed = True + instance._closed.set() instance._on_exit() assert len(recwarn) == 0 # reset staged mutations for cleanup instance._staged_entries = [] - @pytest.mark.asyncio + @CrossSync.pytest async def test_atexit_registration(self): """Should run _on_exit on program termination""" import atexit @@ -1039,30 +1077,29 @@ async def test_atexit_registration(self): async with self._make_one(): assert register_mock.call_count == 1 - @pytest.mark.asyncio - @mock.patch( - "google.cloud.bigtable.data._async.mutations_batcher._MutateRowsOperationAsync", - ) - async def test_timeout_args_passed(self, mutate_rows): + @CrossSync.pytest + async def test_timeout_args_passed(self): """ batch_operation_timeout and batch_attempt_timeout should be used in api calls """ - mutate_rows.return_value = AsyncMock() - expected_operation_timeout = 17 - expected_attempt_timeout = 13 - async with self._make_one( - batch_operation_timeout=expected_operation_timeout, - batch_attempt_timeout=expected_attempt_timeout, - ) as instance: - assert instance._operation_timeout == expected_operation_timeout - assert instance._attempt_timeout == expected_attempt_timeout - # make simulated gapic call - await instance._execute_mutate_rows([_make_mutation()]) - assert mutate_rows.call_count == 1 - kwargs = mutate_rows.call_args[1] - assert kwargs["operation_timeout"] == expected_operation_timeout - assert kwargs["attempt_timeout"] == expected_attempt_timeout + with mock.patch.object( + CrossSync, "_MutateRowsOperation", return_value=CrossSync.Mock() + ) as mutate_rows: + expected_operation_timeout = 17 + expected_attempt_timeout = 13 + async with self._make_one( + batch_operation_timeout=expected_operation_timeout, + batch_attempt_timeout=expected_attempt_timeout, + ) as instance: + assert 
instance._operation_timeout == expected_operation_timeout + assert instance._attempt_timeout == expected_attempt_timeout + # make simulated gapic call + await instance._execute_mutate_rows([self._make_mutation()]) + assert mutate_rows.call_count == 1 + kwargs = mutate_rows.call_args[1] + assert kwargs["operation_timeout"] == expected_operation_timeout + assert kwargs["attempt_timeout"] == expected_attempt_timeout @pytest.mark.parametrize( "limit,in_e,start_e,end_e", @@ -1123,7 +1160,7 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): for i in range(1, newest_list_diff + 1): assert mock_batcher._newest_exceptions[-i] == input_list[-i] - @pytest.mark.asyncio + @CrossSync.pytest # test different inputs for retryable exceptions @pytest.mark.parametrize( "input_retryables,expected_retryables", @@ -1148,6 +1185,7 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): ([4], [core_exceptions.DeadlineExceeded]), ], ) + @CrossSync.convert async def test_customizable_retryable_errors( self, input_retryables, expected_retryables ): @@ -1155,25 +1193,21 @@ async def test_customizable_retryable_errors( Test that retryable functions support user-configurable arguments, and that the configured retryables are passed down to the gapic layer. """ - from google.cloud.bigtable.data._async.client import TableAsync - - with mock.patch( - "google.api_core.retry.if_exception_type" + with mock.patch.object( + google.api_core.retry, "if_exception_type" ) as predicate_builder_mock: - with mock.patch( - "google.api_core.retry.retry_target_async" - ) as retry_fn_mock: + with mock.patch.object(CrossSync, "retry_target") as retry_fn_mock: table = None with mock.patch("asyncio.create_task"): - table = TableAsync(mock.Mock(), "instance", "table") + table = CrossSync.Table(mock.Mock(), "instance", "table") async with self._make_one( table, batch_retryable_errors=input_retryables ) as instance: assert instance._retryable_errors == expected_retryables - expected_predicate = lambda a: a in expected_retryables # noqa + expected_predicate = expected_retryables.__contains__ predicate_builder_mock.return_value = expected_predicate retry_fn_mock.side_effect = RuntimeError("stop early") - mutation = _make_mutation(count=1, size=1) + mutation = self._make_mutation(count=1, size=1) await instance._execute_mutate_rows([mutation]) # passed in errors should be used to build the predicate predicate_builder_mock.assert_called_once_with( diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py new file mode 100644 index 000000000..b30f7544f --- /dev/null +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -0,0 +1,351 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import os +import warnings +import pytest +import mock + +from itertools import zip_longest + +from google.cloud.bigtable_v2 import ReadRowsResponse + +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.row import Row + +from ...v2_client.test_row_merger import ReadRowsTest, TestFile + +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + + +@CrossSync.export_sync( + path="tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", +) +class TestReadRowsAcceptanceAsync: + @staticmethod + @CrossSync.convert + def _get_operation_class(): + return CrossSync._ReadRowsOperation + + @staticmethod + @CrossSync.convert + def _get_client_class(): + return CrossSync.DataClient + + def parse_readrows_acceptance_tests(): + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "../read-rows-acceptance-test.json") + + with open(filename) as json_file: + test_json = TestFile.from_json(json_file.read()) + return test_json.read_rows_tests + + @staticmethod + def extract_results_from_row(row: Row): + results = [] + for family, col, cells in row.items(): + for cell in cells: + results.append( + ReadRowsTest.Result( + row_key=row.row_key, + family_name=family, + qualifier=col, + timestamp_micros=cell.timestamp_ns // 1000, + value=cell.value, + label=(cell.labels[0] if cell.labels else ""), + ) + ) + return results + + @staticmethod + @CrossSync.convert + async def _coro_wrapper(stream): + return stream + + @CrossSync.convert + async def _process_chunks(self, *chunks): + async def _row_stream(): + yield ReadRowsResponse(chunks=chunks) + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + results = [] + async for row in merger: + results.append(row) + return results + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + @CrossSync.pytest + async def test_row_merger_scenario(self, test_case: ReadRowsTest): + async def _scenerio_stream(): + for chunk in test_case.chunks: + yield ReadRowsResponse(chunks=[chunk]) + + try: + results = [] + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_scenerio_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + async for row in merger: + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + @CrossSync.pytest + async def test_read_rows_scenario(self, test_case: ReadRowsTest): + async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list): + self.chunk_list = chunk_list + self.idx = -1 + + def __aiter__(self): + return self 
+ + def __iter__(self): + return self + + async def __anext__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + chunk = self.chunk_list[self.idx] + return ReadRowsResponse(chunks=[chunk]) + raise CrossSync.StopIteration + + def __next__(self): + return self.__anext__() + + def cancel(self): + pass + + return mock_stream(chunk_list) + + with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # use emulator mode to avoid auth issues in CI + client = self._get_client_class()() + try: + table = client.get_table("instance", "table") + results = [] + with mock.patch.object( + table.client._gapic_client, "read_rows" + ) as read_rows: + # run once, then return error on retry + read_rows.return_value = _make_gapic_stream(test_case.chunks) + async for row in await table.read_rows_stream(query={}): + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + finally: + await client.close() + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + @CrossSync.pytest + async def test_out_of_order_rows(self): + async def _row_stream(): + yield ReadRowsResponse(last_scanned_row_key=b"a") + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = b"b" + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + with pytest.raises(InvalidChunk): + async for _ in merger: + pass + + @CrossSync.pytest + async def test_bare_reset(self): + first_chunk = ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk( + row_key=b"a", family_name="f", qualifier=b"q", value=b"v" + ) + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, family_name="f") + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) + ), + ) + with pytest.raises(InvalidChunk): + await self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, value=b"v") + ), + ) + + @CrossSync.pytest + async def test_missing_family(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + qualifier=b"q", + timestamp_micros=1000, + value=b"v", + commit_row=True, + ) + ) + + @CrossSync.pytest + async def test_mid_cell_row_key_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + 
ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), + ) + + @CrossSync.pytest + async def test_mid_cell_family_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + family_name="f2", value=b"v", commit_row=True + ), + ) + + @CrossSync.pytest + async def test_mid_cell_qualifier_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + qualifier=b"q2", value=b"v", commit_row=True + ), + ) + + @CrossSync.pytest + async def test_mid_cell_timestamp_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + timestamp_micros=2000, value=b"v", commit_row=True + ), + ) + + @CrossSync.pytest + async def test_mid_cell_labels_change(self): + with pytest.raises(InvalidChunk): + await self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), + ) diff --git a/tests/unit/data/test_read_rows_acceptance.py b/tests/unit/data/test_read_rows_acceptance.py deleted file mode 100644 index 7cb3c08dc..000000000 --- a/tests/unit/data/test_read_rows_acceptance.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from __future__ import annotations - -import os -from itertools import zip_longest - -import pytest -import mock - -from google.cloud.bigtable_v2 import ReadRowsResponse - -from google.cloud.bigtable.data._async.client import BigtableDataClientAsync -from google.cloud.bigtable.data.exceptions import InvalidChunk -from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync -from google.cloud.bigtable.data.row import Row - -from ..v2_client.test_row_merger import ReadRowsTest, TestFile - - -def parse_readrows_acceptance_tests(): - dirname = os.path.dirname(__file__) - filename = os.path.join(dirname, "./read-rows-acceptance-test.json") - - with open(filename) as json_file: - test_json = TestFile.from_json(json_file.read()) - return test_json.read_rows_tests - - -def extract_results_from_row(row: Row): - results = [] - for family, col, cells in row.items(): - for cell in cells: - results.append( - ReadRowsTest.Result( - row_key=row.row_key, - family_name=family, - qualifier=col, - timestamp_micros=cell.timestamp_ns // 1000, - value=cell.value, - label=(cell.labels[0] if cell.labels else ""), - ) - ) - return results - - -@pytest.mark.parametrize( - "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description -) -@pytest.mark.asyncio -async def test_row_merger_scenario(test_case: ReadRowsTest): - async def _scenerio_stream(): - for chunk in test_case.chunks: - yield ReadRowsResponse(chunks=[chunk]) - - try: - results = [] - instance = mock.Mock() - instance._last_yielded_row_key = None - instance._remaining_count = None - chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_scenerio_stream()) - ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) - async for row in merger: - for cell in row: - cell_result = ReadRowsTest.Result( - row_key=cell.row_key, - family_name=cell.family, - qualifier=cell.qualifier, - timestamp_micros=cell.timestamp_micros, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - -@pytest.mark.parametrize( - "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description -) -@pytest.mark.asyncio -async def test_read_rows_scenario(test_case: ReadRowsTest): - async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): - from google.cloud.bigtable_v2 import ReadRowsResponse - - class mock_stream: - def __init__(self, chunk_list): - self.chunk_list = chunk_list - self.idx = -1 - - def __aiter__(self): - return self - - async def __anext__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - chunk = self.chunk_list[self.idx] - return ReadRowsResponse(chunks=[chunk]) - raise StopAsyncIteration - - def cancel(self): - pass - - return mock_stream(chunk_list) - - try: - with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): - # use emulator mode to avoid auth issues in CI - client = BigtableDataClientAsync() - table = client.get_table("instance", "table") - results = [] - with mock.patch.object(table.client._gapic_client, "read_rows") as read_rows: - # run once, then return error on retry - read_rows.return_value = _make_gapic_stream(test_case.chunks) - async for row in await table.read_rows_stream(query={}): - for cell in row: - cell_result = ReadRowsTest.Result( - row_key=cell.row_key, - family_name=cell.family, - qualifier=cell.qualifier, - 
timestamp_micros=cell.timestamp_micros, - value=cell.value, - label=cell.labels[0] if cell.labels else "", - ) - results.append(cell_result) - except InvalidChunk: - results.append(ReadRowsTest.Result(error=True)) - finally: - await client.close() - for expected, actual in zip_longest(test_case.results, results): - assert actual == expected - - -@pytest.mark.asyncio -async def test_out_of_order_rows(): - async def _row_stream(): - yield ReadRowsResponse(last_scanned_row_key=b"a") - - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = b"b" - chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_row_stream()) - ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) - with pytest.raises(InvalidChunk): - async for _ in merger: - pass - - -@pytest.mark.asyncio -async def test_bare_reset(): - first_chunk = ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk( - row_key=b"a", family_name="f", qualifier=b"q", value=b"v" - ) - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, family_name="f") - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) - ), - ) - with pytest.raises(InvalidChunk): - await _process_chunks( - first_chunk, - ReadRowsResponse.CellChunk( - ReadRowsResponse.CellChunk(reset_row=True, value=b"v") - ), - ) - - -@pytest.mark.asyncio -async def test_missing_family(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - qualifier=b"q", - timestamp_micros=1000, - value=b"v", - commit_row=True, - ) - ) - - -@pytest.mark.asyncio -async def test_mid_cell_row_key_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), - ) - - -@pytest.mark.asyncio -async def test_mid_cell_family_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(family_name="f2", value=b"v", commit_row=True), - ) - - -@pytest.mark.asyncio -async def test_mid_cell_qualifier_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(qualifier=b"q2", value=b"v", commit_row=True), - ) - - -@pytest.mark.asyncio -async def test_mid_cell_timestamp_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - 
family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk( - timestamp_micros=2000, value=b"v", commit_row=True - ), - ) - - -@pytest.mark.asyncio -async def test_mid_cell_labels_change(): - with pytest.raises(InvalidChunk): - await _process_chunks( - ReadRowsResponse.CellChunk( - row_key=b"a", - family_name="f", - qualifier=b"q", - timestamp_micros=1000, - value_size=2, - value=b"v", - ), - ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), - ) - - -async def _coro_wrapper(stream): - return stream - - -async def _process_chunks(*chunks): - async def _row_stream(): - yield ReadRowsResponse(chunks=chunks) - - instance = mock.Mock() - instance._remaining_count = None - instance._last_yielded_row_key = None - chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_row_stream()) - ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) - results = [] - async for row in merger: - results.append(row) - return results From 698902d721e7bfa19ddd5692cfa9ef2b6507b9bc Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 18 Jul 2024 12:38:44 -0700 Subject: [PATCH 194/360] added sync pooled transport --- .../bigtable_v2/services/bigtable/client.py | 2 + .../bigtable/transports/pooled_grpc.py | 445 ++++++++++++++++++ 2 files changed, 447 insertions(+) create mode 100644 google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py diff --git a/google/cloud/bigtable_v2/services/bigtable/client.py b/google/cloud/bigtable_v2/services/bigtable/client.py index 7eda705b9..4a380651d 100644 --- a/google/cloud/bigtable_v2/services/bigtable/client.py +++ b/google/cloud/bigtable_v2/services/bigtable/client.py @@ -56,6 +56,7 @@ from .transports.grpc import BigtableGrpcTransport from .transports.grpc_asyncio import BigtableGrpcAsyncIOTransport from .transports.pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport +from .transports.pooled_grpc import PooledBigtableGrpcTransport from .transports.rest import BigtableRestTransport @@ -71,6 +72,7 @@ class BigtableClientMeta(type): _transport_registry["grpc"] = BigtableGrpcTransport _transport_registry["grpc_asyncio"] = BigtableGrpcAsyncIOTransport _transport_registry["pooled_grpc_asyncio"] = PooledBigtableGrpcAsyncIOTransport + _transport_registry["pooled_grpc"] = PooledBigtableGrpcTransport _transport_registry["rest"] = BigtableRestTransport def get_transport_class( diff --git a/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py new file mode 100644 index 000000000..2c808a000 --- /dev/null +++ b/google/cloud/bigtable_v2/services/bigtable/transports/pooled_grpc.py @@ -0,0 +1,445 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import warnings +from functools import partialmethod +from functools import partial +import time +from typing import ( + Awaitable, + Callable, + Dict, + Optional, + Sequence, + Tuple, + Union, + List, + Type, +) + +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.bigtable_v2.types import bigtable +from .base import BigtableTransport, DEFAULT_CLIENT_INFO +from .grpc import BigtableGrpcTransport + + +class PooledMultiCallable: + def __init__(self, channel_pool: "PooledChannel", *args, **kwargs): + self._init_args = args + self._init_kwargs = kwargs + self.next_channel_fn = channel_pool.next_channel + + def with_call(self, *args, **kwargs): + raise NotImplementedError() + + def future(self, *args, **kwargs): + raise NotImplementedError() + + +class PooledUnaryUnaryMultiCallable(PooledMultiCallable, grpc.UnaryUnaryMultiCallable): + def __call__(self, *args, **kwargs): + return self.next_channel_fn().unary_unary( + *self._init_args, **self._init_kwargs + )(*args, **kwargs) + + +class PooledUnaryStreamMultiCallable( + PooledMultiCallable, grpc.UnaryStreamMultiCallable +): + def __call__(self, *args, **kwargs): + return self.next_channel_fn().unary_stream( + *self._init_args, **self._init_kwargs + )(*args, **kwargs) + + +class PooledStreamUnaryMultiCallable( + PooledMultiCallable, grpc.StreamUnaryMultiCallable +): + def __call__(self, *args, **kwargs): + return self.next_channel_fn().stream_unary( + *self._init_args, **self._init_kwargs + )(*args, **kwargs) + + +class PooledStreamStreamMultiCallable( + PooledMultiCallable, grpc.StreamStreamMultiCallable +): + def __call__(self, *args, **kwargs): + return self.next_channel_fn().stream_stream( + *self._init_args, **self._init_kwargs + )(*args, **kwargs) + + +class PooledChannel(grpc.Channel): + def __init__( + self, + pool_size: int = 3, + host: str = "bigtable.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + quota_project_id: Optional[str] = None, + default_scopes: Optional[Sequence[str]] = None, + scopes: Optional[Sequence[str]] = None, + default_host: Optional[str] = None, + insecure: bool = False, + **kwargs, + ): + self._pool: List[grpc.Channel] = [] + self._next_idx = 0 + if insecure: + self._create_channel = partial(grpc.insecure_channel, host) + else: + self._create_channel = partial( + grpc_helpers.create_channel, + target=host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=default_scopes, + scopes=scopes, + default_host=default_host, + **kwargs, + ) + for i in range(pool_size): + self._pool.append(self._create_channel()) + + def next_channel(self) -> grpc.Channel: + channel = self._pool[self._next_idx] + self._next_idx = (self._next_idx + 1) % len(self._pool) + return channel + + def unary_unary(self, *args, **kwargs) -> grpc.UnaryUnaryMultiCallable: + return PooledUnaryUnaryMultiCallable(self, *args, **kwargs) + + def unary_stream(self, *args, **kwargs) -> grpc.UnaryStreamMultiCallable: + return PooledUnaryStreamMultiCallable(self, *args, **kwargs) + + def stream_unary(self, *args, **kwargs) -> grpc.StreamUnaryMultiCallable: + return PooledStreamUnaryMultiCallable(self, *args, **kwargs) + + def stream_stream(self, *args, **kwargs) -> grpc.StreamStreamMultiCallable: + return 
PooledStreamStreamMultiCallable(self, *args, **kwargs) + + def close(self): + for channel in self._pool: + channel.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: + raise NotImplementedError() + + def wait_for_state_change(self, last_observed_state): + raise NotImplementedError() + + def subscribe( + self, callback, try_to_connect: bool = False + ) -> grpc.ChannelConnectivity: + raise NotImplementedError() + + def unsubscribe(self, callback): + raise NotImplementedError() + + def replace_channel( + self, channel_idx, grace=1, new_channel=None, event=None + ) -> grpc.Channel: + """ + Replaces a channel in the pool with a fresh one. + + The `new_channel` will start processing new requests immidiately, + but the old channel will continue serving existing clients for + `grace` seconds + + Args: + channel_idx(int): the channel index in the pool to replace + grace(Optional[float]): The time to wait for active RPCs to + finish. If a grace period is not specified (by passing None for + grace), all existing RPCs are cancelled immediately. + new_channel(grpc.Channel): a new channel to insert into the pool + at `channel_idx`. If `None`, a new channel will be created. + event(Optional[threading.Event]): an event to signal when the + replacement should be aborted. If set, will call `event.wait()` + instead of the `time.sleep` function. + """ + if channel_idx >= len(self._pool) or channel_idx < 0: + raise ValueError( + f"invalid channel_idx {channel_idx} for pool size {len(self._pool)}" + ) + if new_channel is None: + new_channel = self._create_channel() + old_channel = self._pool[channel_idx] + self._pool[channel_idx] = new_channel + if event: + event.wait(grace) + else: + time.sleep(grace) + old_channel.close() + return new_channel + + +class PooledBigtableGrpcTransport(BigtableGrpcTransport): + """Pooled gRPC backend transport for Bigtable. + + Service for reading from and writing to existing Bigtable + tables. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + + This class allows channel pooling, so multiple channels can be used concurrently + when making requests. Channels are rotated in a round-robin fashion. + """ + + @classmethod + def with_fixed_size(cls, pool_size) -> Type["PooledBigtableGrpcTransport"]: + """ + Creates a new class with a fixed channel pool size. + + A fixed channel pool makes compatibility with other transports easier, + as the initializer signature is the same. + """ + + class PooledTransportFixed(cls): + __init__ = partialmethod(cls.__init__, pool_size=pool_size) + + PooledTransportFixed.__name__ = f"{cls.__name__}_{pool_size}" + PooledTransportFixed.__qualname__ = PooledTransportFixed.__name__ + return PooledTransportFixed + + @classmethod + def create_channel( + cls, + pool_size: int = 3, + host: str = "bigtable.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a PooledChannel object, representing a pool of gRPC channels + Args: + pool_size (int): The number of channels in the pool. 
+ host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + PooledChannel: a channel pool object + """ + + return PooledChannel( + pool_size, + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + pool_size: int = 3, + host: str = "bigtable.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + pool_size (int): the number of grpc channels to maintain in a pool + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. 
+            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                A callback to provide client certificate bytes and private key bytes,
+                both in PEM format. It is used to configure a mutual TLS channel. It is
+                ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
+
+        Raises:
+            google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+            ValueError: if ``pool_size`` <= 0
+        """
+        if pool_size <= 0:
+            raise ValueError(f"invalid pool_size: {pool_size}")
+        self._ssl_channel_credentials = ssl_channel_credentials
+        self._stubs: Dict[str, Callable] = {}
+
+        if api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+        if client_cert_source:
+            warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+        if api_mtls_endpoint:
+            host = api_mtls_endpoint
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                self._ssl_channel_credentials = SslCredentials().ssl_credentials
+
+        else:
+            if client_cert_source_for_mtls and not ssl_channel_credentials:
+                cert, key = client_cert_source_for_mtls()
+                self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+
+        # The base transport sets the host, credentials and scopes
+        BigtableTransport.__init__(
+            self,
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+            always_use_jwt_access=always_use_jwt_access,
+            api_audience=api_audience,
+        )
+        self._quota_project_id = quota_project_id
+        self._grpc_channel = type(self).create_channel(
+            pool_size,
+            self._host,
+            # use the credentials which are saved
+            credentials=self._credentials,
+            # Set ``credentials_file`` to ``None`` here as
+            # the credentials that we saved earlier should be used.
+            credentials_file=None,
+            scopes=self._scopes,
+            ssl_credentials=self._ssl_channel_credentials,
+            quota_project_id=self._quota_project_id,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+        # Wrap messages. This must be done after self._grpc_channel exists
+        self._prep_wrapped_messages(client_info)
+
+    @property
+    def pool_size(self) -> int:
+        """The number of grpc channels in the pool."""
+        return len(self._grpc_channel._pool)
+
+    @property
+    def channels(self) -> List[grpc.Channel]:
+        """Access the internal list of grpc channels."""
+        return self._grpc_channel._pool
+
+    def replace_channel(
+        self, channel_idx, grace=1, new_channel=None, event=None
+    ) -> grpc.Channel:
+        """
+        Replaces a channel in the pool with a fresh one.
+ + The `new_channel` will start processing new requests immidiately, + but the old channel will continue serving existing clients for `grace` seconds + + Args: + channel_idx(int): the channel index in the pool to replace + grace(Optional[float]): The time to wait for active RPCs to + finish. If a grace period is not specified (by passing None for + grace), all existing RPCs are cancelled immediately. + new_channel(grpc.Channel): a new channel to insert into the pool + at `channel_idx`. If `None`, a new channel will be created. + event(Optional[threading.Event]): an event to signal when the + replacement should be aborted. If set, will call `event.wait()` + instead of the `time.sleep` function. + """ + return self._grpc_channel.replace_channel( + channel_idx=channel_idx, grace=grace, new_channel=new_channel, event=event + ) + + +__all__ = ("PooledBigtableGrpcTransport",) From 64166dbd03008652f665f8db2f255840ce2c52cc Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 18 Jul 2024 12:39:15 -0700 Subject: [PATCH 195/360] only strip CrossSync.rm_aio on Call visits --- .cross_sync/transformers.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 6b48421d3..cc8a18763 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -148,18 +148,15 @@ def visit_AsyncWith(self, node): Async with statements are not fully wrapped by calls """ found_rmaio = False - new_items = [] for item in node.items: if isinstance(item.context_expr, ast.Call) and isinstance(item.context_expr.func, ast.Attribute) and isinstance(item.context_expr.func.value, ast.Name) and \ item.context_expr.func.attr == "rm_aio" and "CrossSync" in item.context_expr.func.value.id: found_rmaio = True - new_items.append(item.context_expr.args[0]) - else: - new_items.append(item) + break if found_rmaio: new_node = ast.copy_location( ast.With( - [self.generic_visit(item) for item in new_items], + [self.generic_visit(item) for item in node.items], [self.generic_visit(stmt) for stmt in node.body], ), node, @@ -177,7 +174,7 @@ def visit_AsyncFor(self, node): return ast.copy_location( ast.For( self.visit(node.target), - self.visit(node.iter.args[0]), + self.visit(it), [self.visit(stmt) for stmt in node.body], [self.visit(stmt) for stmt in node.orelse], ), From 73ba23b7fc4d714446a462dc325fa9b641e4e102 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 18 Jul 2024 12:39:56 -0700 Subject: [PATCH 196/360] support decorators in nested functions --- .cross_sync/transformers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index cc8a18763..f630e29eb 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -201,6 +201,8 @@ def visit_AsyncFunctionDef(self, node): node = handler.sync_ast_transform(node, globals()) if node is None: return None + # recurse to any nested functions + node = self.generic_visit(node) except ValueError: # keep unknown decorators node.decorator_list.append(decorator) From 12e88606940a000cab2c4f33f028ccadb0ce4bee Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 13:37:24 -0600 Subject: [PATCH 197/360] added missing convert annotations --- google/cloud/bigtable/data/_async/client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index f18b46256..6428dfa0c 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ 
b/google/cloud/bigtable/data/_async/client.py @@ -830,6 +830,7 @@ async def read_rows_sharded( # limit the number of concurrent requests using a semaphore concurrency_sem = CrossSync.Semaphore(_CONCURRENCY_LIMIT) + @CrossSync.convert async def read_rows_with_semaphore(query): async with CrossSync.rm_aio(concurrency_sem): # calculate new timeout based on time left in overall operation @@ -989,6 +990,7 @@ async def sample_row_keys( # prepare request metadata = _make_metadata(self.table_name, self.app_profile_id) + @CrossSync.convert async def execute_rpc(): results = CrossSync.rm_aio( await self.client._gapic_client.sample_row_keys( From 63a528a979c7cbde0dd281c3105158b6a9b9c315 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 13:38:07 -0600 Subject: [PATCH 198/360] added import blocks for sync versions --- .../bigtable/data/_async/_mutate_rows.py | 3 +++ google/cloud/bigtable/data/_async/client.py | 22 +++++++++++++++++++ .../bigtable/data/_async/mutations_batcher.py | 2 ++ tests/unit/data/_async/test__mutate_rows.py | 4 ++++ tests/unit/data/_async/test_client.py | 7 ++++++ .../data/_async/test_mutations_batcher.py | 4 ++++ .../data/_async/test_read_rows_acceptance.py | 6 +++++ 7 files changed, 48 insertions(+) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index e62d43397..a3f530fd2 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -39,6 +39,9 @@ ) CrossSync.add_mapping("GapicClient", BigtableAsyncClient) + else: + from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient + CrossSync.add_mapping("GapicClient", BigtableClient) @CrossSync.export_sync( diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 6428dfa0c..e6b254ee6 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -98,7 +98,29 @@ CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) CrossSync.add_mapping("MutationsBatcher", MutationsBatcherAsync) +else: + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledBigtableGrpcTransport, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledChannel + ) + from google.cloud.bigtable_v2.services.bigtable.client import ( + BigtableClient, + ) + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation + from google.cloud.bigtable.data._sync.mutations_batcher import ( + MutationsBatcher, + ) + # define file-specific cross-sync replacements + CrossSync.add_mapping("GapicClient", BigtableClient) + CrossSync.add_mapping("PooledTransport", PooledBigtableGrpcTransport) + CrossSync.add_mapping("PooledChannel", PooledChannel) + CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperation) + CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperation) + CrossSync.add_mapping("MutationsBatcher", MutationsBatcher) if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 7a6def9e4..253bfea4a 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ 
b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -39,6 +39,8 @@ if CrossSync.is_async: from google.cloud.bigtable.data._async.client import TableAsync + else: + from google.cloud.bigtable.data._sync.client import Table @CrossSync.export_sync( diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index a307a7008..21e5464ce 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -27,6 +27,10 @@ except ImportError: # pragma: NO COVER import mock # type: ignore +if not CrossSync.is_async: + from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation + CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperation) + @CrossSync.export_sync( path="tests.unit.data._sync.test__mutate_rows.TestMutateRowsOperation", diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index b51987c5d..d689ee2d5 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -45,6 +45,13 @@ from google.cloud.bigtable.data._async.client import TableAsync CrossSync.add_mapping("grpc_helpers", grpc_helpers_async) +else: + from google.api_core import grpc_helpers + from google.cloud.bigtable.data._sync.client import Table + from google.cloud.bigtable.data._sync.client import BigtableDataClient + + CrossSync.add_mapping("grpc_helpers", grpc_helpers) + CrossSync.add_mapping("DataClient", BigtableDataClient) @CrossSync.export_sync( diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index fcd425273..37298d552 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -28,6 +28,10 @@ except ImportError: # pragma: NO COVER import mock # type: ignore +if not CrossSync.is_async: + from google.cloud.bigtable.data._sync.client import Table + CrossSync.add_mapping("Table", Table) + @CrossSync.export_sync( path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl" diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index b30f7544f..2b5a25a74 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -29,6 +29,12 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if not CrossSync.is_async: + from google.cloud.bigtable.data._sync.client import BigtableDataClient + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + CrossSync.add_mapping("DataClient", BigtableDataClient) + CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperation) + @CrossSync.export_sync( path="tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", From f5dcdf5cdb360914cdecadbe90baacacc15238ec Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 13:55:36 -0600 Subject: [PATCH 199/360] added annotations to tests --- tests/system/data/test_system_async.py | 25 ++++++++++--------- tests/unit/data/_async/test_client.py | 15 +++++------ .../data/_async/test_mutations_batcher.py | 1 + .../data/_async/test_read_rows_acceptance.py | 5 ++-- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index d12936305..0a8aba471 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -61,7 +61,7 @@ async 
def add_row( } ], } - await self.table.client._gapic_client.mutate_row(request) + CrossSync.rm_aio(await self.table.client._gapic_client.mutate_row(request)) self.rows.append(row_key) @CrossSync.convert @@ -74,7 +74,7 @@ async def delete_rows(self): for row in self.rows ], } - await self.table.client._gapic_client.mutate_rows(request) + CrossSync.rm_aio(await self.table.client._gapic_client.mutate_rows(request)) @CrossSync.export_sync(path="tests.system.data.test_system.TestSystem") @@ -83,18 +83,19 @@ class TestSystemAsync: @CrossSync.pytest_fixture(scope="session") async def client(self): project = os.getenv("GOOGLE_CLOUD_PROJECT") or None - async with CrossSync.DataClient(project=project, pool_size=4) as client: + async with CrossSync.rm_aio(CrossSync.DataClient(project=project, pool_size=4)) as client: yield client @CrossSync.convert @CrossSync.pytest_fixture(scope="session") async def table(self, client, table_id, instance_id): - async with client.get_table( + async with CrossSync.rm_aio(client.get_table( instance_id, table_id, - ) as table: + )) as table: yield table + @CrossSync.drop_method @pytest.fixture(scope="session") def event_loop(self): loop = asyncio.get_event_loop() @@ -141,7 +142,7 @@ async def _retrieve_cell_value(self, table, row_key): """ from google.cloud.bigtable.data import ReadRowsQuery - row_list = await table.read_rows(ReadRowsQuery(row_keys=row_key)) + row_list = CrossSync.rm_aio(await table.read_rows(ReadRowsQuery(row_keys=row_key))) assert len(row_list) == 1 row = row_list[0] cell = row.cells[0] @@ -159,11 +160,11 @@ async def _create_row_and_mutation( row_key = uuid.uuid4().hex.encode() family = TEST_FAMILY qualifier = b"test-qualifier" - await temp_rows.add_row( + CrossSync.rm_aio(await temp_rows.add_row( row_key, family=family, qualifier=qualifier, value=start_value - ) + )) # ensure cell is initialized - assert (await self._retrieve_cell_value(table, row_key)) == start_value + assert CrossSync.rm_aio(await self._retrieve_cell_value(table, row_key)) == start_value mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) return row_key, mutation @@ -173,7 +174,7 @@ async def _create_row_and_mutation( async def temp_rows(self, table): builder = CrossSync.TempRowBuilder(table) yield builder - await builder.delete_rows() + CrossSync.rm_aio(await builder.delete_rows()) @pytest.mark.usefixtures("table") @pytest.mark.usefixtures("client") @@ -219,9 +220,9 @@ async def test_mutation_set_cell(self, table, temp_rows): """ row_key = b"bulk_mutate" new_value = uuid.uuid4().hex.encode() - row_key, mutation = await self._create_row_and_mutation( + row_key, mutation = CrossSync.rm_aio(await self._create_row_and_mutation( table, temp_rows, new_value=new_value - ) + )) await table.mutate_row(row_key, mutation) # ensure cell is updated diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index d689ee2d5..7218fe828 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1508,17 +1508,16 @@ def __init__(self, chunk_list, sleep_time): self.idx = -1 self.sleep_time = sleep_time + @CrossSync.convert(sync_name="__iter__") def __aiter__(self): return self - def __iter__(self): - return self - + @CrossSync.convert(sync_name="__next__") async def __anext__(self): self.idx += 1 if len(self.chunk_list) > self.idx: if sleep_time: - await CrossSync.sleep(self.sleep_time) + CrossSync.rm_aio(await CrossSync.sleep(self.sleep_time)) chunk = self.chunk_list[self.idx] if isinstance(chunk, 
Exception): raise chunk @@ -1526,9 +1525,6 @@ async def __anext__(self): return ReadRowsResponse(chunks=[chunk]) raise CrossSync.StopIteration - def __next__(self): - return self.__anext__() - def cancel(self): pass @@ -1536,7 +1532,7 @@ def cancel(self): @CrossSync.convert async def execute_fn(self, table, *args, **kwargs): - return await table.read_rows(*args, **kwargs) + return CrossSync.rm_aio(await table.read_rows(*args, **kwargs)) @CrossSync.pytest async def test_read_rows(self): @@ -2032,7 +2028,7 @@ async def test_read_rows_sharded_concurrent(self): import time async def mock_call(*args, **kwargs): - await asyncio.sleep(0.1) + await CrossSync.sleep(0.1) return [mock.Mock()] async with self._make_client() as client: @@ -2527,6 +2523,7 @@ async def _mock_response(self, response_list): for i in range(len(response_list)) ] + @CrossSync.convert async def generator(): yield MutateRowsResponse(entries=entries) diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 37298d552..ec407ca79 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -897,6 +897,7 @@ async def _mock_gapic_return(self, num=5): from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 + @CrossSync.convert async def gen(num): for i in range(num): entry = MutateRowsResponse.Entry( diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 2b5a25a74..fc6cfdf46 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -82,6 +82,7 @@ async def _coro_wrapper(stream): @CrossSync.convert async def _process_chunks(self, *chunks): + @CrossSync.convert async def _row_stream(): yield ReadRowsResponse(chunks=chunks) @@ -93,7 +94,7 @@ async def _row_stream(): ) merger = self._get_operation_class().merge_rows(chunker) results = [] - async for row in merger: + async for row in CrossSync.rm_aio(merger): results.append(row) return results @@ -115,7 +116,7 @@ async def _scenerio_stream(): instance, self._coro_wrapper(_scenerio_stream()) ) merger = self._get_operation_class().merge_rows(chunker) - async for row in merger: + async for row in CrossSync.rm_aio(merger): for cell in row: cell_result = ReadRowsTest.Result( row_key=cell.row_key, From be407714433928a0b7ecc02318648808f53e53ce Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 13:55:52 -0600 Subject: [PATCH 200/360] removed unneeded test imports --- tests/unit/data/_async/test__mutate_rows.py | 4 ---- tests/unit/data/_async/test_client.py | 2 -- tests/unit/data/_async/test_mutations_batcher.py | 4 ---- tests/unit/data/_async/test_read_rows_acceptance.py | 6 ------ 4 files changed, 16 deletions(-) diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 21e5464ce..a307a7008 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -27,10 +27,6 @@ except ImportError: # pragma: NO COVER import mock # type: ignore -if not CrossSync.is_async: - from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation - CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperation) - @CrossSync.export_sync( path="tests.unit.data._sync.test__mutate_rows.TestMutateRowsOperation", diff --git a/tests/unit/data/_async/test_client.py 
b/tests/unit/data/_async/test_client.py index 7218fe828..60f555248 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -48,10 +48,8 @@ else: from google.api_core import grpc_helpers from google.cloud.bigtable.data._sync.client import Table - from google.cloud.bigtable.data._sync.client import BigtableDataClient CrossSync.add_mapping("grpc_helpers", grpc_helpers) - CrossSync.add_mapping("DataClient", BigtableDataClient) @CrossSync.export_sync( diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index ec407ca79..2c61d005a 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -28,10 +28,6 @@ except ImportError: # pragma: NO COVER import mock # type: ignore -if not CrossSync.is_async: - from google.cloud.bigtable.data._sync.client import Table - CrossSync.add_mapping("Table", Table) - @CrossSync.export_sync( path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl" diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index fc6cfdf46..0bd5d82f8 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -29,12 +29,6 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -if not CrossSync.is_async: - from google.cloud.bigtable.data._sync.client import BigtableDataClient - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - CrossSync.add_mapping("DataClient", BigtableDataClient) - CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperation) - @CrossSync.export_sync( path="tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", From a676d496d8de2a394a35356b0ea54f30716c6372 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 14:04:31 -0600 Subject: [PATCH 201/360] ran blacken --- .../bigtable/data/_async/_mutate_rows.py | 1 + google/cloud/bigtable/data/_async/client.py | 2 +- tests/system/data/test_system_async.py | 37 ++++++++++++------- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index a3f530fd2..f1c016e4c 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -41,6 +41,7 @@ CrossSync.add_mapping("GapicClient", BigtableAsyncClient) else: from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient + CrossSync.add_mapping("GapicClient", BigtableClient) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index e6b254ee6..bc389f2d4 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -103,7 +103,7 @@ PooledBigtableGrpcTransport, ) from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledChannel + PooledChannel, ) from google.cloud.bigtable_v2.services.bigtable.client import ( BigtableClient, diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 0a8aba471..28a89a8e2 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -83,16 +83,20 @@ class TestSystemAsync: @CrossSync.pytest_fixture(scope="session") async def client(self): project = os.getenv("GOOGLE_CLOUD_PROJECT") or None - async with 
CrossSync.rm_aio(CrossSync.DataClient(project=project, pool_size=4)) as client: + async with CrossSync.rm_aio( + CrossSync.DataClient(project=project, pool_size=4) + ) as client: yield client @CrossSync.convert @CrossSync.pytest_fixture(scope="session") async def table(self, client, table_id, instance_id): - async with CrossSync.rm_aio(client.get_table( - instance_id, - table_id, - )) as table: + async with CrossSync.rm_aio( + client.get_table( + instance_id, + table_id, + ) + ) as table: yield table @CrossSync.drop_method @@ -142,7 +146,9 @@ async def _retrieve_cell_value(self, table, row_key): """ from google.cloud.bigtable.data import ReadRowsQuery - row_list = CrossSync.rm_aio(await table.read_rows(ReadRowsQuery(row_keys=row_key))) + row_list = CrossSync.rm_aio( + await table.read_rows(ReadRowsQuery(row_keys=row_key)) + ) assert len(row_list) == 1 row = row_list[0] cell = row.cells[0] @@ -160,11 +166,16 @@ async def _create_row_and_mutation( row_key = uuid.uuid4().hex.encode() family = TEST_FAMILY qualifier = b"test-qualifier" - CrossSync.rm_aio(await temp_rows.add_row( - row_key, family=family, qualifier=qualifier, value=start_value - )) + CrossSync.rm_aio( + await temp_rows.add_row( + row_key, family=family, qualifier=qualifier, value=start_value + ) + ) # ensure cell is initialized - assert CrossSync.rm_aio(await self._retrieve_cell_value(table, row_key)) == start_value + assert ( + CrossSync.rm_aio(await self._retrieve_cell_value(table, row_key)) + == start_value + ) mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) return row_key, mutation @@ -220,9 +231,9 @@ async def test_mutation_set_cell(self, table, temp_rows): """ row_key = b"bulk_mutate" new_value = uuid.uuid4().hex.encode() - row_key, mutation = CrossSync.rm_aio(await self._create_row_and_mutation( - table, temp_rows, new_value=new_value - )) + row_key, mutation = CrossSync.rm_aio( + await self._create_row_and_mutation(table, temp_rows, new_value=new_value) + ) await table.mutate_row(row_key, mutation) # ensure cell is updated From c63d88ea2fc9e47dd2d2c42c93e1f315fe1b1788 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 14:12:21 -0600 Subject: [PATCH 202/360] fixed lint issues --- google/cloud/bigtable/data/_sync/cross_sync.py | 1 - .../cloud/bigtable/data/_sync/cross_sync_decorators.py | 9 ++++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync.py index 0af5f0c4a..70a3950d0 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync.py @@ -36,7 +36,6 @@ import threading import time from .cross_sync_decorators import ( - AstDecorator, ExportSync, Convert, DropMethod, diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py index 76350f443..2ca763d95 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -136,12 +136,14 @@ def get_for_node(cls, node: ast.Call) -> "AstDecorator": raise ValueError("Not a CrossSync decorator") @classmethod - def _convert_ast_to_py(cls, ast_node: ast.expr) -> Any: + def _convert_ast_to_py(cls, ast_node: ast.expr|None) -> Any: """ Helper to convert ast primitives to python primitives. 
Used when unwrapping arguments """ import ast + if ast_node is None: + return None if isinstance(ast_node, ast.Constant): return ast_node.value if isinstance(ast_node, ast.List): @@ -213,7 +215,6 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): wrapped_node = copy.deepcopy(wrapped_node) # update name sync_cls_name = self.path.rsplit(".", 1)[-1] - orig_name = wrapped_node.name wrapped_node.name = sync_cls_name # strip CrossSync decorators if hasattr(wrapped_node, "decorator_list"): @@ -238,7 +239,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): # convert class contents wrapped_node = transformers_globals["RmAioFunctions"]().visit(wrapped_node) replace_dict = self.replace_symbols or {} - replace_dict.update({"CrossSync": f"CrossSync._Sync_Impl"}) + replace_dict.update({"CrossSync": "CrossSync._Sync_Impl"}) wrapped_node = transformers_globals["SymbolReplacer"](replace_dict).visit( wrapped_node ) @@ -324,8 +325,6 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): """ convert async to sync """ - import ast - converted = transformers_globals["AsyncToSync"]().visit(wrapped_node) return converted From 5d11fc0e3a20b930fd3d250849716d56df1e49d7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 15:03:25 -0600 Subject: [PATCH 203/360] added generated files --- .../cloud/bigtable/data/_sync/_mutate_rows.py | 190 ++ .../cloud/bigtable/data/_sync/_read_rows.py | 305 ++ google/cloud/bigtable/data/_sync/client.py | 1197 +++++++ .../bigtable/data/_sync/mutations_batcher.py | 449 +++ tests/system/data/test_system.py | 808 +++++ tests/unit/data/_sync/__init__.py | 0 tests/unit/data/_sync/test__mutate_rows.py | 310 ++ tests/unit/data/_sync/test__read_rows.py | 357 +++ tests/unit/data/_sync/test_client.py | 2744 +++++++++++++++++ .../unit/data/_sync/test_mutations_batcher.py | 1081 +++++++ .../data/_sync/test_read_rows_acceptance.py | 327 ++ 11 files changed, 7768 insertions(+) create mode 100644 google/cloud/bigtable/data/_sync/_mutate_rows.py create mode 100644 google/cloud/bigtable/data/_sync/_read_rows.py create mode 100644 google/cloud/bigtable/data/_sync/client.py create mode 100644 google/cloud/bigtable/data/_sync/mutations_batcher.py create mode 100644 tests/system/data/test_system.py create mode 100644 tests/unit/data/_sync/__init__.py create mode 100644 tests/unit/data/_sync/test__mutate_rows.py create mode 100644 tests/unit/data/_sync/test__read_rows.py create mode 100644 tests/unit/data/_sync/test_client.py create mode 100644 tests/unit/data/_sync/test_mutations_batcher.py create mode 100644 tests/unit/data/_sync/test_read_rows_acceptance.py diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py new file mode 100644 index 000000000..8e01f95e8 --- /dev/null +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -0,0 +1,190 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by CrossSync. 
Do not edit manually. +from __future__ import annotations +from typing import Sequence, TYPE_CHECKING +import functools +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +import google.cloud.bigtable.data.exceptions as bt_exceptions +from google.cloud.bigtable.data._helpers import _make_metadata +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable.data.mutations import _EntryWithProto +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +if TYPE_CHECKING: + from google.cloud.bigtable.data.mutations import RowMutationEntry + + if CrossSync._Sync_Impl.is_async: + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) + + CrossSync._Sync_Impl.add_mapping("GapicClient", BigtableAsyncClient) + else: + from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient + + CrossSync._Sync_Impl.add_mapping("GapicClient", BigtableClient) + + +@CrossSync._Sync_Impl.add_mapping_decorator("_MutateRowsOperation") +class _MutateRowsOperation: + """ + MutateRowsOperation manages the logic of sending a set of row mutations, + and retrying on failed entries. It manages this using the _run_attempt + function, which attempts to mutate all outstanding entries, and raises + _MutateRowsIncomplete if any retryable errors are encountered. + + Errors are exposed as a MutationsExceptionGroup, which contains a list of + exceptions organized by the related failed mutation entries. + + Args: + gapic_client: the client to use for the mutate_rows call + table: the table associated with the request + mutation_entries: a list of RowMutationEntry objects to send to the server + operation_timeout: the timeout to use for the entire operation, in seconds. + attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds. + If not specified, the request will run until operation_timeout is reached. + """ + + def __init__( + self, + gapic_client: "CrossSync.GapicClient", + table: "CrossSync.Table", + mutation_entries: list["RowMutationEntry"], + operation_timeout: float, + attempt_timeout: float | None, + retryable_exceptions: Sequence[type[Exception]] = (), + ): + total_mutations = sum((len(entry.mutations) for entry in mutation_entries)) + if total_mutations > _MUTATE_ROWS_REQUEST_MUTATION_LIMIT: + raise ValueError( + f"mutate_rows requests can contain at most {_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across all entries. Found {total_mutations}." 
+ ) + metadata = _make_metadata(table.table_name, table.app_profile_id) + self._gapic_fn = functools.partial( + gapic_client.mutate_rows, + table_name=table.table_name, + app_profile_id=table.app_profile_id, + metadata=metadata, + retry=None, + ) + self.is_retryable = retries.if_exception_type( + *retryable_exceptions, bt_exceptions._MutateRowsIncomplete + ) + sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + self._operation = lambda: CrossSync._Sync_Impl.retry_target( + self._run_attempt, + self.is_retryable, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) + self.timeout_generator = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries] + self.remaining_indices = list(range(len(self.mutations))) + self.errors: dict[int, list[Exception]] = {} + + def start(self): + """Start the operation, and run until completion + + Raises: + MutationsExceptionGroup: if any mutations failed""" + try: + self._operation() + except Exception as exc: + incomplete_indices = self.remaining_indices.copy() + for idx in incomplete_indices: + self._handle_entry_error(idx, exc) + finally: + all_errors: list[Exception] = [] + for idx, exc_list in self.errors.items(): + if len(exc_list) == 0: + raise core_exceptions.ClientError( + f"Mutation {idx} failed with no associated errors" + ) + elif len(exc_list) == 1: + cause_exc = exc_list[0] + else: + cause_exc = bt_exceptions.RetryExceptionGroup(exc_list) + entry = self.mutations[idx].entry + all_errors.append( + bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc) + ) + if all_errors: + raise bt_exceptions.MutationsExceptionGroup( + all_errors, len(self.mutations) + ) + + def _run_attempt(self): + """Run a single attempt of the mutate_rows rpc. + + Raises: + _MutateRowsIncomplete: if there are failed mutations eligible for + retry after the attempt is complete + GoogleAPICallError: if the gapic rpc fails""" + request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] + active_request_indices = { + req_idx: orig_idx for req_idx, orig_idx in enumerate(self.remaining_indices) + } + self.remaining_indices = [] + if not request_entries: + return + try: + result_generator = self._gapic_fn( + timeout=next(self.timeout_generator), + entries=request_entries, + retry=None, + ) + for result_list in result_generator: + for result in result_list.entries: + orig_idx = active_request_indices[result.index] + entry_error = core_exceptions.from_grpc_status( + result.status.code, + result.status.message, + details=result.status.details, + ) + if result.status.code != 0: + self._handle_entry_error(orig_idx, entry_error) + elif orig_idx in self.errors: + del self.errors[orig_idx] + del active_request_indices[result.index] + except Exception as exc: + for idx in active_request_indices.values(): + self._handle_entry_error(idx, exc) + raise + if self.remaining_indices: + raise bt_exceptions._MutateRowsIncomplete + + def _handle_entry_error(self, idx: int, exc: Exception): + """Add an exception to the list of exceptions for a given mutation index, + and add the index to the list of remaining indices if the exception is + retryable. 
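        For example, under the default mutate_rows retryable errors a DeadlineExceeded
        raised for an idempotent entry is re-queued for the next attempt, while any
        error hitting a non-idempotent entry (for instance a SetCell that uses the
        server-side timestamp) is only recorded and the entry is not retried.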
+ + Args: + idx: the index of the mutation that failed + exc: the exception to add to the list""" + entry = self.mutations[idx].entry + self.errors.setdefault(idx, []).append(exc) + if ( + entry.is_idempotent() + and self.is_retryable(exc) + and (idx not in self.remaining_indices) + ): + self.remaining_indices.append(idx) diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py new file mode 100644 index 000000000..08e0dfbb2 --- /dev/null +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -0,0 +1,305 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by CrossSync. Do not edit manually. +from __future__ import annotations +from typing import Sequence +from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB +from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB +from google.cloud.bigtable_v2.types import RowSet as RowSetPB +from google.cloud.bigtable_v2.types import RowRange as RowRangePB +from google.cloud.bigtable.data.row import Row, Cell +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.exceptions import _RowSetComplete +from google.cloud.bigtable.data.exceptions import _ResetRow +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data._helpers import _make_metadata +from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.api_core import retry as retries +from google.api_core.retry import exponential_sleep_generator +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + + +@CrossSync._Sync_Impl.add_mapping_decorator("_ReadRowsOperation") +class _ReadRowsOperation: + """ + ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream + into a stream of Row objects. + + ReadRowsOperation.merge_row_response_stream takes in a stream of ReadRowsResponse + and turns them into a stream of Row objects using an internal + StateMachine. + + ReadRowsOperation(request, client) handles row merging logic end-to-end, including + performing retries on stream errors. 
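        For example, a minimal sketch of driving the generated sync operation directly
        (the query, table, and timeout values below are placeholders, not part of the
        generated module)::

            operation = _ReadRowsOperation(
                query, table, operation_timeout=60.0, attempt_timeout=20.0
            )
            for row in operation.start_operation():
                ...  # each yielded item is a Row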
+ + Args: + query: The query to execute + table: The table to send the request to + operation_timeout: The total time to allow for the operation, in seconds + attempt_timeout: The time to allow for each individual attempt, in seconds + retryable_exceptions: A list of exceptions that should trigger a retry + """ + + __slots__ = ( + "attempt_timeout_gen", + "operation_timeout", + "request", + "table", + "_predicate", + "_metadata", + "_last_yielded_row_key", + "_remaining_count", + ) + + def __init__( + self, + query: ReadRowsQuery, + table: "CrossSync.Table", + operation_timeout: float, + attempt_timeout: float, + retryable_exceptions: Sequence[type[Exception]] = (), + ): + self.attempt_timeout_gen = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + self.operation_timeout = operation_timeout + if isinstance(query, dict): + self.request = ReadRowsRequestPB( + **query, + table_name=table.table_name, + app_profile_id=table.app_profile_id, + ) + else: + self.request = query._to_pb(table) + self.table = table + self._predicate = retries.if_exception_type(*retryable_exceptions) + self._metadata = _make_metadata(table.table_name, table.app_profile_id) + self._last_yielded_row_key: bytes | None = None + self._remaining_count: int | None = self.request.rows_limit or None + + def start_operation(self) -> CrossSync._Sync_Impl.Iterable[Row]: + """Start the read_rows operation, retrying on retryable errors. + + Yields: + Row: The next row in the stream""" + return CrossSync._Sync_Impl.retry_target_stream( + self._read_rows_attempt, + self._predicate, + exponential_sleep_generator(0.01, 60, multiplier=2), + self.operation_timeout, + exception_factory=_retry_exception_factory, + ) + + def _read_rows_attempt(self) -> CrossSync._Sync_Impl.Iterable[Row]: + """Attempt a single read_rows rpc call. + This function is intended to be wrapped by retry logic, + which will call this function until it succeeds or + a non-retryable error is raised. 
+ + Yields: + Row: The next row in the stream""" + if self._last_yielded_row_key is not None: + try: + self.request.rows = self._revise_request_rowset( + row_set=self.request.rows, + last_seen_row_key=self._last_yielded_row_key, + ) + except _RowSetComplete: + return self.merge_rows(None) + if self._remaining_count is not None: + self.request.rows_limit = self._remaining_count + if self._remaining_count == 0: + return self.merge_rows(None) + gapic_stream = self.table.client._gapic_client.read_rows( + self.request, + timeout=next(self.attempt_timeout_gen), + metadata=self._metadata, + retry=None, + ) + chunked_stream = self.chunk_stream(gapic_stream) + return self.merge_rows(chunked_stream) + + def chunk_stream( + self, + stream: CrossSync._Sync_Impl.Awaitable[ + CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB] + ], + ) -> CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB.CellChunk]: + """process chunks out of raw read_rows stream + + Args: + stream: the raw read_rows stream from the gapic client + Yields: + ReadRowsResponsePB.CellChunk: the next chunk in the stream""" + for resp in stream: + resp = resp._pb + if resp.last_scanned_row_key: + if ( + self._last_yielded_row_key is not None + and resp.last_scanned_row_key <= self._last_yielded_row_key + ): + raise InvalidChunk("last scanned out of order") + self._last_yielded_row_key = resp.last_scanned_row_key + current_key = None + for c in resp.chunks: + if current_key is None: + current_key = c.row_key + if current_key is None: + raise InvalidChunk("first chunk is missing a row key") + elif ( + self._last_yielded_row_key + and current_key <= self._last_yielded_row_key + ): + raise InvalidChunk("row keys should be strictly increasing") + yield c + if c.reset_row: + current_key = None + elif c.commit_row: + self._last_yielded_row_key = current_key + if self._remaining_count is not None: + self._remaining_count -= 1 + if self._remaining_count < 0: + raise InvalidChunk("emit count exceeds row limit") + current_key = None + + @staticmethod + def merge_rows( + chunks: CrossSync._Sync_Impl.Iterable[ReadRowsResponsePB.CellChunk] | None, + ) -> CrossSync._Sync_Impl.Iterable[Row]: + """Merge chunks into rows + + Args: + chunks: the chunk stream to merge + Yields: + Row: the next row in the stream""" + if chunks is None: + return + it = chunks.__iter__() + while True: + try: + c = it.__next__() + except CrossSync._Sync_Impl.StopIteration: + return + row_key = c.row_key + if not row_key: + raise InvalidChunk("first row chunk is missing key") + cells = [] + family: str | None = None + qualifier: bytes | None = None + try: + while True: + if c.reset_row: + raise _ResetRow(c) + k = c.row_key + f = c.family_name.value + q = c.qualifier.value if c.HasField("qualifier") else None + if k and k != row_key: + raise InvalidChunk("unexpected new row key") + if f: + family = f + if q is not None: + qualifier = q + else: + raise InvalidChunk("new family without qualifier") + elif family is None: + raise InvalidChunk("missing family") + elif q is not None: + if family is None: + raise InvalidChunk("new qualifier without family") + qualifier = q + elif qualifier is None: + raise InvalidChunk("missing qualifier") + ts = c.timestamp_micros + labels = c.labels if c.labels else [] + value = c.value + if c.value_size > 0: + buffer = [value] + while c.value_size > 0: + c = it.__next__() + t = c.timestamp_micros + cl = c.labels + k = c.row_key + if ( + c.HasField("family_name") + and c.family_name.value != family + ): + raise InvalidChunk("family changed mid cell") + if ( + 
c.HasField("qualifier") + and c.qualifier.value != qualifier + ): + raise InvalidChunk("qualifier changed mid cell") + if t and t != ts: + raise InvalidChunk("timestamp changed mid cell") + if cl and cl != labels: + raise InvalidChunk("labels changed mid cell") + if k and k != row_key: + raise InvalidChunk("row key changed mid cell") + if c.reset_row: + raise _ResetRow(c) + buffer.append(c.value) + value = b"".join(buffer) + cells.append( + Cell(value, row_key, family, qualifier, ts, list(labels)) + ) + if c.commit_row: + yield Row(row_key, cells) + break + c = it.__next__() + except _ResetRow as e: + c = e.chunk + if ( + c.row_key + or c.HasField("family_name") + or c.HasField("qualifier") + or c.timestamp_micros + or c.labels + or c.value + ): + raise InvalidChunk("reset row with data") + continue + except CrossSync._Sync_Impl.StopIteration: + raise InvalidChunk("premature end of stream") + + @staticmethod + def _revise_request_rowset(row_set: RowSetPB, last_seen_row_key: bytes) -> RowSetPB: + """Revise the rows in the request to avoid ones we've already processed. + + Args: + row_set: the row set from the request + last_seen_row_key: the last row key encountered + Returns: + RowSetPB: the new rowset after adusting for the last seen key + Raises: + _RowSetComplete: if there are no rows left to process after the revision""" + if row_set is None or (not row_set.row_ranges and (not row_set.row_keys)): + last_seen = last_seen_row_key + return RowSetPB(row_ranges=[RowRangePB(start_key_open=last_seen)]) + adjusted_keys: list[bytes] = [ + k for k in row_set.row_keys if k > last_seen_row_key + ] + adjusted_ranges: list[RowRangePB] = [] + for row_range in row_set.row_ranges: + end_key = row_range.end_key_closed or row_range.end_key_open or None + if end_key is None or end_key > last_seen_row_key: + new_range = RowRangePB(row_range) + start_key = row_range.start_key_closed or row_range.start_key_open + if start_key is None or start_key <= last_seen_row_key: + new_range.start_key_open = last_seen_row_key + adjusted_ranges.append(new_range) + if len(adjusted_keys) == 0 and len(adjusted_ranges) == 0: + raise _RowSetComplete() + return RowSetPB(row_keys=adjusted_keys, row_ranges=adjusted_ranges) diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py new file mode 100644 index 000000000..4bbd5498b --- /dev/null +++ b/google/cloud/bigtable/data/_sync/client.py @@ -0,0 +1,1197 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by CrossSync. Do not edit manually. 
+from __future__ import annotations +from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING +import time +import warnings +import random +import os +import concurrent.futures +from functools import partial +from grpc import Channel +from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta +from google.cloud.bigtable_v2.services.bigtable.transports.base import ( + DEFAULT_CLIENT_INFO, +) +from google.cloud.bigtable_v2.types.bigtable import PingAndWarmRequest +from google.cloud.client import ClientWithProject +from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.api_core import retry as retries +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import ServiceUnavailable +from google.api_core.exceptions import Aborted +import google.auth.credentials +import google.auth._default +from google.api_core import client_options as client_options_lib +from google.cloud.bigtable.client import _DEFAULT_BIGTABLE_EMULATOR_CLIENT +from google.cloud.bigtable.data.row import Row +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data.exceptions import FailedQueryShardError +from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _WarmedInstanceKey +from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT +from google.cloud.bigtable.data._helpers import _make_metadata +from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data._helpers import _validate_timeouts +from google.cloud.bigtable.data._helpers import _get_retryable_errors +from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data._helpers import _MB_SIZE +from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry +from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule +from google.cloud.bigtable.data.row_filters import RowFilter +from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter +from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter +from google.cloud.bigtable.data.row_filters import RowFilterChain +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +if CrossSync._Sync_Impl.is_async: + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledBigtableGrpcAsyncIOTransport, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel as AsyncPooledChannel, + ) + from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, + ) + from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync + from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) + + CrossSync._Sync_Impl.add_mapping("GapicClient", BigtableAsyncClient) + CrossSync._Sync_Impl.add_mapping( + "PooledTransport", PooledBigtableGrpcAsyncIOTransport + ) + CrossSync._Sync_Impl.add_mapping("PooledChannel", AsyncPooledChannel) + CrossSync._Sync_Impl.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) + CrossSync._Sync_Impl.add_mapping("_MutateRowsOperation", 
_MutateRowsOperationAsync) + CrossSync._Sync_Impl.add_mapping("MutationsBatcher", MutationsBatcherAsync) +else: + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledBigtableGrpcTransport, + ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledChannel, + ) + from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation + from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation + from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher + + CrossSync._Sync_Impl.add_mapping("GapicClient", BigtableClient) + CrossSync._Sync_Impl.add_mapping("PooledTransport", PooledBigtableGrpcTransport) + CrossSync._Sync_Impl.add_mapping("PooledChannel", PooledChannel) + CrossSync._Sync_Impl.add_mapping("_ReadRowsOperation", _ReadRowsOperation) + CrossSync._Sync_Impl.add_mapping("_MutateRowsOperation", _MutateRowsOperation) + CrossSync._Sync_Impl.add_mapping("MutationsBatcher", MutationsBatcher) +if TYPE_CHECKING: + from google.cloud.bigtable.data._helpers import RowKeySamples + from google.cloud.bigtable.data._helpers import ShardedQuery + + +@CrossSync._Sync_Impl.add_mapping_decorator("DataClient") +class BigtableDataClient(ClientWithProject): + def __init__( + self, + *, + project: str | None = None, + pool_size: int = 3, + credentials: google.auth.credentials.Credentials | None = None, + client_options: dict[str, Any] + | "google.api_core.client_options.ClientOptions" + | None = None, + ): + """Create a client instance for the Bigtable Data API + + Client should be created within an async context (running event loop) + + Args: + project: the project which the client acts on behalf of. + If not passed, falls back to the default inferred + from the environment. + pool_size: The number of grpc channels to maintain + in the internal channel pool. + credentials: + Thehe OAuth2 Credentials to use for this + client. If not passed (and if no ``_http`` object is + passed), falls back to the default inferred from the + environment. + client_options: + Client options used to set user options + on the client. API Endpoint should be set through client_options. 
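        For example, a minimal sketch of typical use of the generated sync client
        (the project, instance, table, and row key values are placeholders)::

            with BigtableDataClient(project="my-project") as client:
                table = client.get_table("my-instance", "my-table")
                row = table.read_row(b"row-key")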
+ Raises: + RuntimeError: if called outside of an async context (no running event loop) + ValueError: if pool_size is less than 1""" + transport_str = f"bt-{self._client_version()}-{pool_size}" + transport = CrossSync._Sync_Impl.PooledTransport.with_fixed_size(pool_size) + BigtableClientMeta._transport_registry[transport_str] = transport + client_info = DEFAULT_CLIENT_INFO + client_info.client_library_version = self._client_version() + if type(client_options) is dict: + client_options = client_options_lib.from_dict(client_options) + client_options = cast( + Optional[client_options_lib.ClientOptions], client_options + ) + self._emulator_host = os.getenv(BIGTABLE_EMULATOR) + if self._emulator_host is not None: + if credentials is None: + credentials = google.auth.credentials.AnonymousCredentials() + if project is None: + project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT + ClientWithProject.__init__( + self, + credentials=credentials, + project=project, + client_options=client_options, + ) + self._gapic_client = CrossSync._Sync_Impl.GapicClient( + transport=transport_str, + credentials=credentials, + client_options=client_options, + client_info=client_info, + ) + self._is_closed = CrossSync._Sync_Impl.Event() + self.transport = cast( + CrossSync._Sync_Impl.PooledTransport, self._gapic_client.transport + ) + self._active_instances: Set[_WarmedInstanceKey] = set() + self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} + self._channel_init_time = time.monotonic() + self._channel_refresh_tasks: list[CrossSync._Sync_Impl.Task[None]] = [] + self._executor = ( + concurrent.futures.ThreadPoolExecutor() + if not CrossSync._Sync_Impl.is_async + else None + ) + if self._emulator_host is not None: + warnings.warn( + "Connecting to Bigtable emulator at {}".format(self._emulator_host), + RuntimeWarning, + stacklevel=2, + ) + self.transport._grpc_channel = CrossSync._Sync_Impl.PooledChannel( + pool_size=pool_size, host=self._emulator_host, insecure=True + ) + self.transport._stubs = {} + self.transport._prep_wrapped_messages(client_info) + else: + try: + self._start_background_channel_refresh() + except RuntimeError: + warnings.warn( + f"{self.__class__.__name__} should be started in an asyncio event loop. 
Channel refresh will not be started", + RuntimeWarning, + stacklevel=2, + ) + + @staticmethod + def _client_version() -> str: + """Helper function to return the client version string for this client""" + version_str = f"{google.cloud.bigtable.__version__}-data" + if CrossSync._Sync_Impl.is_async: + version_str += "-async" + return version_str + + def _start_background_channel_refresh(self) -> None: + """Starts a background task to ping and warm each channel in the pool + + Raises: + RuntimeError: if not called in an asyncio event loop""" + if ( + not self._channel_refresh_tasks + and (not self._emulator_host) + and (not self._is_closed.is_set()) + ): + CrossSync._Sync_Impl.verify_async_event_loop() + for channel_idx in range(self.transport.pool_size): + refresh_task = CrossSync._Sync_Impl.create_task( + self._manage_channel, + channel_idx, + sync_executor=self._executor, + task_name=f"{self.__class__.__name__} channel refresh {channel_idx}", + ) + self._channel_refresh_tasks.append(refresh_task) + + def close(self, timeout: float | None = 2.0): + """Cancel all background tasks""" + self._is_closed.set() + for task in self._channel_refresh_tasks: + task.cancel() + self.transport.close() + if self._executor: + self._executor.shutdown(wait=False) + CrossSync._Sync_Impl.wait(self._channel_refresh_tasks, timeout=timeout) + self._channel_refresh_tasks = [] + + def _ping_and_warm_instances( + self, channel: Channel, instance_key: _WarmedInstanceKey | None = None + ) -> list[BaseException | None]: + """Prepares the backend for requests on a channel + + Pings each Bigtable instance registered in `_active_instances` on the client + + Args: + channel: grpc channel to warm + instance_key: if provided, only warm the instance associated with the key + Returns: + list[BaseException | None]: sequence of results or exceptions from the ping requests + """ + instance_list = ( + [instance_key] if instance_key is not None else self._active_instances + ) + ping_rpc = channel.unary_unary( + "/google.bigtable.v2.Bigtable/PingAndWarm", + request_serializer=PingAndWarmRequest.serialize, + ) + partial_list = [ + partial( + ping_rpc, + request={"name": instance_name, "app_profile_id": app_profile_id}, + metadata=[ + ( + "x-goog-request-params", + f"name={instance_name}&app_profile_id={app_profile_id}", + ) + ], + wait_for_ready=True, + ) + for instance_name, table_name, app_profile_id in instance_list + ] + result_list = CrossSync._Sync_Impl.gather_partials( + partial_list, return_exceptions=True, sync_executor=self._executor + ) + return [r or None for r in result_list] + + def _manage_channel( + self, + channel_idx: int, + refresh_interval_min: float = 60 * 35, + refresh_interval_max: float = 60 * 45, + grace_period: float = 60 * 10, + ) -> None: + """Background coroutine that periodically refreshes and warms a grpc channel + + The backend will automatically close channels after 60 minutes, so + `refresh_interval` + `grace_period` should be < 60 minutes + + Runs continuously until the client is closed + + Args: + channel_idx: index of the channel in the transport's channel pool + refresh_interval_min: minimum interval before initiating refresh + process in seconds. Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + refresh_interval_max: maximum interval before initiating refresh + process in seconds. 
Actual interval will be a random value + between `refresh_interval_min` and `refresh_interval_max` + grace_period: time to allow previous channel to serve existing + requests before closing, in seconds""" + first_refresh = self._channel_init_time + random.uniform( + refresh_interval_min, refresh_interval_max + ) + next_sleep = max(first_refresh - time.monotonic(), 0) + if next_sleep > 0: + channel = self.transport.channels[channel_idx] + self._ping_and_warm_instances(channel) + while not self._is_closed.is_set(): + CrossSync._Sync_Impl.event_wait( + self._is_closed, next_sleep, async_break_early=False + ) + if self._is_closed.is_set(): + break + new_channel = self.transport.grpc_channel._create_channel() + self._ping_and_warm_instances(new_channel) + start_timestamp = time.monotonic() + self.transport.replace_channel( + channel_idx, + grace=grace_period, + new_channel=new_channel, + event=self._is_closed, + ) + next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) + next_sleep = next_refresh - (time.monotonic() - start_timestamp) + + def _register_instance(self, instance_id: str, owner: Table) -> None: + """Registers an instance with the client, and warms the channel pool + for the instance + The client will periodically refresh grpc channel pool used to make + requests, and new channels will be warmed for each registered instance + Channels will not be refreshed unless at least one instance is registered + + Args: + instance_id: id of the instance to register. + owner: table that owns the instance. Owners will be tracked in + _instance_owners, and instances will only be unregistered when all + owners call _remove_instance_registration""" + instance_name = self._gapic_client.instance_path(self.project, instance_id) + instance_key = _WarmedInstanceKey( + instance_name, owner.table_name, owner.app_profile_id + ) + self._instance_owners.setdefault(instance_key, set()).add(id(owner)) + if instance_name not in self._active_instances: + self._active_instances.add(instance_key) + if self._channel_refresh_tasks: + for channel in self.transport.channels: + self._ping_and_warm_instances(channel, instance_key) + else: + self._start_background_channel_refresh() + + def _remove_instance_registration(self, instance_id: str, owner: Table) -> bool: + """Removes an instance from the client's registered instances, to prevent + warming new channels for the instance + + If instance_id is not registered, or is still in use by other tables, returns False + + Args: + instance_id: id of the instance to remove + owner: table that owns the instance. Owners will be tracked in + _instance_owners, and instances will only be unregistered when all + owners call _remove_instance_registration + Returns: + bool: True if instance was removed, else False""" + instance_name = self._gapic_client.instance_path(self.project, instance_id) + instance_key = _WarmedInstanceKey( + instance_name, owner.table_name, owner.app_profile_id + ) + owner_list = self._instance_owners.get(instance_key, set()) + try: + owner_list.remove(id(owner)) + if len(owner_list) == 0: + self._active_instances.remove(instance_key) + return True + except KeyError: + return False + + def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> Table: + """Returns a table instance for making data API requests. All arguments are passed + directly to the Table constructor. + + Args: + instance_id: The Bigtable instance ID to associate with this client. 
+ instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. table_id is combined with the + instance_id and the client's project to fully specify the table + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to 20 seconds + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to 60 seconds + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + Returns: + TableAsync: a table instance for making data API requests + Raises: + RuntimeError: if called outside of an async context (no running event loop) + """ + return Table(self, instance_id, table_id, *args, **kwargs) + + def __enter__(self): + self._start_background_channel_refresh() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + self._gapic_client.__exit__(exc_type, exc_val, exc_tb) + + +@CrossSync._Sync_Impl.add_mapping_decorator("Table") +class Table: + """ + Main Data API surface + + Table object maintains table_id, and app_profile_id context, and passes them with + each call + """ + + def __init__( + self, + client: BigtableDataClient, + instance_id: str, + table_id: str, + app_profile_id: str | None = None, + *, + default_read_rows_operation_timeout: float = 600, + default_read_rows_attempt_timeout: float | None = 20, + default_mutate_rows_operation_timeout: float = 600, + default_mutate_rows_attempt_timeout: float | None = 60, + default_operation_timeout: float = 60, + default_attempt_timeout: float | None = 20, + default_read_rows_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + Aborted, + ), + default_mutate_rows_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + ), + default_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + ), + ): + """Initialize a Table instance + + Must be created within an async context (running event loop) + + Args: + instance_id: The Bigtable instance ID to associate with this client. + instance_id is combined with the client's project to fully + specify the instance + table_id: The ID of the table. 
table_id is combined with the + instance_id and the client's project to fully specify the table + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + default_read_rows_operation_timeout: The default timeout for read rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_read_rows_attempt_timeout: The default timeout for individual + read rows rpc requests, in seconds. If not set, defaults to 20 seconds + default_mutate_rows_operation_timeout: The default timeout for mutate rows + operations, in seconds. If not set, defaults to 600 seconds (10 minutes) + default_mutate_rows_attempt_timeout: The default timeout for individual + mutate rows rpc requests, in seconds. If not set, defaults to 60 seconds + default_operation_timeout: The default timeout for all other operations, in + seconds. If not set, defaults to 60 seconds + default_attempt_timeout: The default timeout for all other individual rpc + requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + Raises: + RuntimeError: if called outside of an async context (no running event loop) + """ + _validate_timeouts( + default_operation_timeout, default_attempt_timeout, allow_none=True + ) + _validate_timeouts( + default_read_rows_operation_timeout, + default_read_rows_attempt_timeout, + allow_none=True, + ) + _validate_timeouts( + default_mutate_rows_operation_timeout, + default_mutate_rows_attempt_timeout, + allow_none=True, + ) + self.client = client + self.instance_id = instance_id + self.instance_name = self.client._gapic_client.instance_path( + self.client.project, instance_id + ) + self.table_id = table_id + self.table_name = self.client._gapic_client.table_path( + self.client.project, instance_id, table_id + ) + self.app_profile_id = app_profile_id + self.default_operation_timeout = default_operation_timeout + self.default_attempt_timeout = default_attempt_timeout + self.default_read_rows_operation_timeout = default_read_rows_operation_timeout + self.default_read_rows_attempt_timeout = default_read_rows_attempt_timeout + self.default_mutate_rows_operation_timeout = ( + default_mutate_rows_operation_timeout + ) + self.default_mutate_rows_attempt_timeout = default_mutate_rows_attempt_timeout + self.default_read_rows_retryable_errors = ( + default_read_rows_retryable_errors or () + ) + self.default_mutate_rows_retryable_errors = ( + default_mutate_rows_retryable_errors or () + ) + self.default_retryable_errors = default_retryable_errors or () + try: + self._register_instance_future = CrossSync._Sync_Impl.create_task( + self.client._register_instance, + self.instance_id, + self, + sync_executor=self.client._executor, + ) + except RuntimeError as e: + raise RuntimeError( + f"{self.__class__.__name__} must be created within an async event loop context." 
+ ) from e + + def read_rows_stream( + self, + query: ReadRowsQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> Iterable[Row]: + """Read a set of rows from the table, based on the specified query. + Returns an iterator to asynchronously stream back row data. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Args: + query: contains details about which rows to return + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors + Returns: + AsyncIterable[Row]: an asynchronous iterator that yields rows returned by the query + Raises: + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error + """ + operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + retryable_excs = _get_retryable_errors(retryable_errors, self) + row_merger = CrossSync._Sync_Impl._ReadRowsOperation( + query, + self, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_exceptions=retryable_excs, + ) + return row_merger.start_operation() + + def read_rows( + self, + query: ReadRowsQuery, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> list[Row]: + """Read a set of rows from the table, based on the specified query. + Retruns results as a list of Row objects when the request is complete. + For streamed results, use read_rows_stream. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Args: + query: contains details about which rows to return + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + If None, defaults to the Table's default_read_rows_attempt_timeout, + or the operation_timeout if that is also None. + retryable_errors: a list of errors that will be retried if encountered. 
+                Defaults to the Table's default_read_rows_retryable_errors.
+        Returns:
+            list[Row]: a list of Rows returned by the query
+        Raises:
+            google.api_core.exceptions.DeadlineExceeded: raised after operation timeout
+                will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
+                from any retries that failed
+            google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error
+        """
+        row_generator = self.read_rows_stream(
+            query,
+            operation_timeout=operation_timeout,
+            attempt_timeout=attempt_timeout,
+            retryable_errors=retryable_errors,
+        )
+        return [row for row in row_generator]
+
+    def read_row(
+        self,
+        row_key: str | bytes,
+        *,
+        row_filter: RowFilter | None = None,
+        operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+        attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+        retryable_errors: Sequence[type[Exception]]
+        | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+    ) -> Row | None:
+        """Read a single row from the table, based on the specified key.
+
+        Failed requests within operation_timeout will be retried based on the
+        retryable_errors list until operation_timeout is reached.
+
+        Args:
+            row_key: the key of the row to read
+            operation_timeout: the time budget for the entire operation, in seconds.
+                Failed requests will be retried within the budget.
+                Defaults to the Table's default_read_rows_operation_timeout
+            attempt_timeout: the time budget for an individual network request, in seconds.
+                If it takes longer than this time to complete, the request will be cancelled with
+                a DeadlineExceeded exception, and a retry will be attempted.
+                Defaults to the Table's default_read_rows_attempt_timeout.
+                If None, defaults to operation_timeout.
+            retryable_errors: a list of errors that will be retried if encountered.
+                Defaults to the Table's default_read_rows_retryable_errors.
+        Returns:
+            Row | None: a Row object if the row exists, otherwise None
+        Raises:
+            google.api_core.exceptions.DeadlineExceeded: raised after operation timeout
+                will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions
+                from any retries that failed
+            google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error
+        """
+        if row_key is None:
+            raise ValueError("row_key must be string or bytes")
+        query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1)
+        results = self.read_rows(
+            query,
+            operation_timeout=operation_timeout,
+            attempt_timeout=attempt_timeout,
+            retryable_errors=retryable_errors,
+        )
+        if len(results) == 0:
+            return None
+        return results[0]
+
+    def read_rows_sharded(
+        self,
+        sharded_query: ShardedQuery,
+        *,
+        operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+        attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+        retryable_errors: Sequence[type[Exception]]
+        | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS,
+    ) -> list[Row]:
+        """Runs a sharded query in parallel, then returns the results in a single list.
+        Results will be returned in the order of the input queries.
+
+        This function is intended to be run on the result of a query.shard() call.
+        For example::
+
+            table_shard_keys = table.sample_row_keys()
+            query = ReadRowsQuery(...)
+            shard_queries = query.shard(table_shard_keys)
+            results = table.read_rows_sharded(shard_queries)
+
+        Args:
+            sharded_query: a sharded query to execute
+            operation_timeout: the time budget for the entire operation, in seconds.
+ Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. + Returns: + list[Row]: a list of Rows returned by the query + Raises: + ShardedReadRowsExceptionGroup: if any of the queries failed + ValueError: if the query_list is empty""" + if not sharded_query: + raise ValueError("empty sharded_query") + operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + rpc_timeout_generator = _attempt_timeout_generator( + operation_timeout, operation_timeout + ) + concurrency_sem = CrossSync._Sync_Impl.Semaphore(_CONCURRENCY_LIMIT) + + def read_rows_with_semaphore(query): + with concurrency_sem: + shard_timeout = next(rpc_timeout_generator) + if shard_timeout <= 0: + raise DeadlineExceeded( + "Operation timeout exceeded before starting query" + ) + return self.read_rows( + query, + operation_timeout=shard_timeout, + attempt_timeout=min(attempt_timeout, shard_timeout), + retryable_errors=retryable_errors, + ) + + routine_list = [ + partial(read_rows_with_semaphore, query) for query in sharded_query + ] + batch_result = CrossSync._Sync_Impl.gather_partials( + routine_list, return_exceptions=True, sync_executor=self.client._executor + ) + error_dict = {} + shard_idx = 0 + results_list = [] + for result in batch_result: + if isinstance(result, Exception): + error_dict[shard_idx] = result + elif isinstance(result, BaseException): + raise result + else: + results_list.extend(result) + shard_idx += 1 + if error_dict: + raise ShardedReadRowsExceptionGroup( + [ + FailedQueryShardError(idx, sharded_query[idx], e) + for idx, e in error_dict.items() + ], + results_list, + len(sharded_query), + ) + return results_list + + def row_exists( + self, + row_key: str | bytes, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + ) -> bool: + """Return a boolean indicating whether the specified row exists in the table. + uses the filters: chain(limit cells per row = 1, strip value) + + Args: + row_key: the key of the row to check + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_read_rows_operation_timeout + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_read_rows_attempt_timeout. + If None, defaults to operation_timeout. + retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. 
+ Returns: + bool: a bool indicating whether the row exists + Raises: + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error + """ + if row_key is None: + raise ValueError("row_key must be string or bytes") + strip_filter = StripValueTransformerFilter(flag=True) + limit_filter = CellsRowLimitFilter(1) + chain_filter = RowFilterChain(filters=[limit_filter, strip_filter]) + query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter) + results = self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, + ) + return len(results) > 0 + + def sample_row_keys( + self, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ) -> RowKeySamples: + """Return a set of RowKeySamples that delimit contiguous sections of the table of + approximately equal size + + RowKeySamples output can be used with ReadRowsQuery.shard() to create a sharded query that + can be parallelized across multiple backend nodes read_rows and read_rows_stream + requests will call sample_row_keys internally for this purpose when sharding is enabled + + RowKeySamples is simply a type alias for list[tuple[bytes, int]]; a list of + row_keys, along with offset positions in the table + + Args: + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget.i + Defaults to the Table's default_operation_timeout + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_attempt_timeout. + If None, defaults to operation_timeout. + retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_retryable_errors. 
+ Returns: + RowKeySamples: a set of RowKeySamples the delimit contiguous sections of the table + Raises: + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error + """ + operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + attempt_timeout_gen = _attempt_timeout_generator( + attempt_timeout, operation_timeout + ) + retryable_excs = _get_retryable_errors(retryable_errors, self) + predicate = retries.if_exception_type(*retryable_excs) + sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + metadata = _make_metadata(self.table_name, self.app_profile_id) + + def execute_rpc(): + results = self.client._gapic_client.sample_row_keys( + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=next(attempt_timeout_gen), + metadata=metadata, + retry=None, + ) + return [(s.row_key, s.offset_bytes) for s in results] + + return CrossSync._Sync_Impl.retry_target( + execute_rpc, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) + + def mutations_batcher( + self, + *, + flush_interval: float | None = 5, + flush_limit_mutation_count: int | None = 1000, + flush_limit_bytes: int = 20 * _MB_SIZE, + flow_control_max_mutation_count: int = 100000, + flow_control_max_bytes: int = 100 * _MB_SIZE, + batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + ) -> MutationsBatcher: + """Returns a new mutations batcher instance. + + Can be used to iteratively add mutations that are flushed as a group, + to avoid excess network calls + + Args: + flush_interval: Automatically flush every flush_interval seconds. If None, + a table default will be used + flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count + mutations are added across all entries. If None, this limit is ignored. + flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. + flow_control_max_mutation_count: Maximum number of inflight mutations. + flow_control_max_bytes: Maximum number of inflight bytes. + batch_operation_timeout: timeout for each mutate_rows operation, in seconds. + Defaults to the Table's default_mutate_rows_operation_timeout + batch_attempt_timeout: timeout for each individual request, in seconds. + Defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to batch_operation_timeout. + batch_retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_mutate_rows_retryable_errors. 
+ Returns: + MutationsBatcherAsync: a MutationsBatcher context manager that can batch requests + """ + return CrossSync._Sync_Impl.MutationsBatcher( + self, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_mutation_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=batch_operation_timeout, + batch_attempt_timeout=batch_attempt_timeout, + batch_retryable_errors=batch_retryable_errors, + ) + + def mutate_row( + self, + row_key: str | bytes, + mutations: list[Mutation] | Mutation, + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ): + """Mutates a row atomically. + + Cells already present in the row are left unchanged unless explicitly changed + by ``mutation``. + + Idempotent operations (i.e, all mutations have an explicit timestamp) will be + retried on server failure. Non-idempotent operations will not. + + Args: + row_key: the row to apply mutations to + mutations: the set of mutations to apply to the row + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to the Table's default_operation_timeout + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the Table's default_attempt_timeout. + If None, defaults to operation_timeout. + retryable_errors: a list of errors that will be retried if encountered. + Only idempotent mutations will be retried. Defaults to the Table's + default_retryable_errors. + Raises: + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing all + GoogleAPIError exceptions from any retries that failed + google.api_core.exceptions.GoogleAPIError: raised on non-idempotent operations that cannot be + safely retried. 
+            ValueError: if invalid arguments are provided"""
+        operation_timeout, attempt_timeout = _get_timeouts(
+            operation_timeout, attempt_timeout, self
+        )
+        if not mutations:
+            raise ValueError("No mutations provided")
+        mutations_list = mutations if isinstance(mutations, list) else [mutations]
+        if all((mutation.is_idempotent() for mutation in mutations_list)):
+            predicate = retries.if_exception_type(
+                *_get_retryable_errors(retryable_errors, self)
+            )
+        else:
+            predicate = retries.if_exception_type()
+        sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60)
+        target = partial(
+            self.client._gapic_client.mutate_row,
+            row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
+            mutations=[mutation._to_pb() for mutation in mutations_list],
+            table_name=self.table_name,
+            app_profile_id=self.app_profile_id,
+            timeout=attempt_timeout,
+            metadata=_make_metadata(self.table_name, self.app_profile_id),
+            retry=None,
+        )
+        return CrossSync._Sync_Impl.retry_target(
+            target,
+            predicate,
+            sleep_generator,
+            operation_timeout,
+            exception_factory=_retry_exception_factory,
+        )
+
+    def bulk_mutate_rows(
+        self,
+        mutation_entries: list[RowMutationEntry],
+        *,
+        operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+        attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+        retryable_errors: Sequence[type[Exception]]
+        | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
+    ):
+        """Applies mutations for multiple rows in a single batched request.
+
+        Each individual RowMutationEntry is applied atomically, but separate entries
+        may be applied in arbitrary order (even for entries targeting the same row).
+        In total, the row_mutations can contain at most 100000 individual mutations
+        across all entries.
+
+        Idempotent entries (i.e., entries with mutations with explicit timestamps)
+        will be retried on failure. Non-idempotent entries will not, and will be
+        reported in a raised exception group.
+
+        Args:
+            mutation_entries: the batches of mutations to apply
+                Each entry will be applied atomically, but entries will be applied
+                in arbitrary order
+            operation_timeout: the time budget for the entire operation, in seconds.
+                Failed requests will be retried within the budget.
+                Defaults to the Table's default_mutate_rows_operation_timeout
+            attempt_timeout: the time budget for an individual network request, in seconds.
+                If it takes longer than this time to complete, the request will be cancelled with
+                a DeadlineExceeded exception, and a retry will be attempted.
+                Defaults to the Table's default_mutate_rows_attempt_timeout.
+                If None, defaults to operation_timeout.
+            retryable_errors: a list of errors that will be retried if encountered.
+ Defaults to the Table's default_mutate_rows_retryable_errors + Raises: + MutationsExceptionGroup: if one or more mutations fails + Contains details about any failed entries in .exceptions + ValueError: if invalid arguments are provided""" + operation_timeout, attempt_timeout = _get_timeouts( + operation_timeout, attempt_timeout, self + ) + retryable_excs = _get_retryable_errors(retryable_errors, self) + operation = CrossSync._Sync_Impl._MutateRowsOperation( + self.client._gapic_client, + self, + mutation_entries, + operation_timeout, + attempt_timeout, + retryable_exceptions=retryable_excs, + ) + operation.start() + + def check_and_mutate_row( + self, + row_key: str | bytes, + predicate: RowFilter | None, + *, + true_case_mutations: Mutation | list[Mutation] | None = None, + false_case_mutations: Mutation | list[Mutation] | None = None, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ) -> bool: + """Mutates a row atomically based on the output of a predicate filter + + Non-idempotent operation: will not be retried + + Args: + row_key: the key of the row to mutate + predicate: the filter to be applied to the contents of the specified row. + Depending on whether or not any results are yielded, + either true_case_mutations or false_case_mutations will be executed. + If None, checks that the row contains any values at all. + true_case_mutations: + Changes to be atomically applied to the specified row if + predicate yields at least one cell when + applied to row_key. Entries are applied in order, + meaning that earlier mutations can be masked by later + ones. Must contain at least one entry if + false_case_mutations is empty, and at most 100000. + false_case_mutations: + Changes to be atomically applied to the specified row if + predicate_filter does not yield any cells when + applied to row_key. Entries are applied in order, + meaning that earlier mutations can be masked by later + ones. Must contain at least one entry if + `true_case_mutations` is empty, and at most 100000. + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will not be retried. 
Defaults to the Table's default_operation_timeout + Returns: + bool indicating whether the predicate was true or false + Raises: + google.api_core.exceptions.GoogleAPIError: exceptions from grpc call""" + operation_timeout, _ = _get_timeouts(operation_timeout, None, self) + if true_case_mutations is not None and ( + not isinstance(true_case_mutations, list) + ): + true_case_mutations = [true_case_mutations] + true_case_list = [m._to_pb() for m in true_case_mutations or []] + if false_case_mutations is not None and ( + not isinstance(false_case_mutations, list) + ): + false_case_mutations = [false_case_mutations] + false_case_list = [m._to_pb() for m in false_case_mutations or []] + metadata = _make_metadata(self.table_name, self.app_profile_id) + result = self.client._gapic_client.check_and_mutate_row( + true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) + return result.predicate_matched + + def read_modify_write_row( + self, + row_key: str | bytes, + rules: ReadModifyWriteRule | list[ReadModifyWriteRule], + *, + operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + ) -> Row: + """Reads and modifies a row atomically according to input ReadModifyWriteRules, + and returns the contents of all modified cells + + The new value for the timestamp is the greater of the existing timestamp or + the current server time. + + Non-idempotent operation: will not be retried + + Args: + row_key: the key of the row to apply read/modify/write rules to + rules: A rule or set of rules to apply to the row. + Rules are applied in order, meaning that earlier rules will affect the + results of later ones. + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will not be retried. + Defaults to the Table's default_operation_timeout. 
+ Returns: + Row: a Row containing cell data that was modified as part of the operation + Raises: + google.api_core.exceptions.GoogleAPIError: exceptions from grpc call + ValueError: if invalid arguments are provided""" + operation_timeout, _ = _get_timeouts(operation_timeout, None, self) + if operation_timeout <= 0: + raise ValueError("operation_timeout must be greater than 0") + if rules is not None and (not isinstance(rules, list)): + rules = [rules] + if not rules: + raise ValueError("rules must contain at least one item") + metadata = _make_metadata(self.table_name, self.app_profile_id) + result = self.client._gapic_client.read_modify_write_row( + rules=[rule._to_pb() for rule in rules], + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) + return Row._from_pb(result.row) + + def close(self): + """Called to close the Table instance and release any resources held by it.""" + if self._register_instance_future: + self._register_instance_future.cancel() + self.client._remove_instance_registration(self.instance_id, self) + + def __enter__(self): + """Implement async context manager protocol + + Ensure registration task has time to run, so that + grpc channels will be warmed for the specified instance""" + if self._register_instance_future: + self._register_instance_future + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Implement async context manager protocol + + Unregister this instance with the client, so that + grpc channels will no longer be warmed""" + self.close() diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py new file mode 100644 index 000000000..232ca6e12 --- /dev/null +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -0,0 +1,449 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by CrossSync. Do not edit manually. 
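As a usage reference, the following is a minimal sketch of how the generated sync surface in this patch fits together. It is illustrative only: the project, instance, and table IDs are placeholders, and the calls simply mirror those exercised by the system tests later in this patch (CrossSync._Sync_Impl.DataClient, Table.mutations_batcher, and Table.read_rows)::

    from google.cloud.bigtable.data._sync.cross_sync import CrossSync
    from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell
    from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery

    # client and table are context managers; exiting them releases their resources
    with CrossSync._Sync_Impl.DataClient(project="my-project") as client:
        with client.get_table("my-instance", "my-table") as table:
            # the batcher flushes on size, count, or interval, and again on exit
            with table.mutations_batcher() as batcher:
                batcher.append(
                    RowMutationEntry(
                        b"row-1",
                        [SetCell(family="fam", qualifier=b"q", new_value=b"v")],
                    )
                )
            # read_rows blocks and returns a list of Row objects
            rows = table.read_rows(ReadRowsQuery(row_keys=[b"row-1"]))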
+ +# mypy: disable-error-code="unreachable" + +from __future__ import annotations +from typing import Sequence, TYPE_CHECKING +import atexit +import warnings +from collections import deque +import concurrent.futures +from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup +from google.cloud.bigtable.data.exceptions import FailedMutationEntryError +from google.cloud.bigtable.data._helpers import _get_retryable_errors +from google.cloud.bigtable.data._helpers import _get_timeouts +from google.cloud.bigtable.data._helpers import TABLE_DEFAULT +from google.cloud.bigtable.data._helpers import _MB_SIZE +from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT +from google.cloud.bigtable.data.mutations import Mutation +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +if TYPE_CHECKING: + from google.cloud.bigtable.data.mutations import RowMutationEntry + + if CrossSync._Sync_Impl.is_async: + pass + else: + from google.cloud.bigtable.data._sync.client import Table + + +@CrossSync._Sync_Impl.add_mapping_decorator("_FlowControl") +class _FlowControl: + """ + Manages flow control for batched mutations. Mutations are registered against + the FlowControl object before being sent, which will block if size or count + limits have reached capacity. As mutations completed, they are removed from + the FlowControl object, which will notify any blocked requests that there + is additional capacity. + + Flow limits are not hard limits. If a single mutation exceeds the configured + limits, it will be allowed as a single batch when the capacity is available. + + Args: + max_mutation_count: maximum number of mutations to send in a single rpc. + This corresponds to individual mutations in a single RowMutationEntry. + max_mutation_bytes: maximum number of bytes to send in a single rpc. + Raises: + ValueError: if max_mutation_count or max_mutation_bytes is less than 0 + """ + + def __init__(self, max_mutation_count: int, max_mutation_bytes: int): + self._max_mutation_count = max_mutation_count + self._max_mutation_bytes = max_mutation_bytes + if self._max_mutation_count < 1: + raise ValueError("max_mutation_count must be greater than 0") + if self._max_mutation_bytes < 1: + raise ValueError("max_mutation_bytes must be greater than 0") + self._capacity_condition = CrossSync._Sync_Impl.Condition() + self._in_flight_mutation_count = 0 + self._in_flight_mutation_bytes = 0 + + def _has_capacity(self, additional_count: int, additional_size: int) -> bool: + """Checks if there is capacity to send a new entry with the given size and count + + FlowControl limits are not hard limits. If a single mutation exceeds + the configured flow limits, it will be sent in a single batch when + previous batches have completed. + + Args: + additional_count: number of mutations in the pending entry + additional_size: size of the pending entry + Returns: + bool: True if there is capacity to send the pending entry, False otherwise + """ + acceptable_size = max(self._max_mutation_bytes, additional_size) + acceptable_count = max(self._max_mutation_count, additional_count) + new_size = self._in_flight_mutation_bytes + additional_size + new_count = self._in_flight_mutation_count + additional_count + return new_size <= acceptable_size and new_count <= acceptable_count + + def remove_from_flow( + self, mutations: RowMutationEntry | list[RowMutationEntry] + ) -> None: + """Removes mutations from flow control. 
This method should be called once + for each mutation that was sent to add_to_flow, after the corresponding + operation is complete. + + Args: + mutations: mutation or list of mutations to remove from flow control""" + if not isinstance(mutations, list): + mutations = [mutations] + total_count = sum((len(entry.mutations) for entry in mutations)) + total_size = sum((entry.size() for entry in mutations)) + self._in_flight_mutation_count -= total_count + self._in_flight_mutation_bytes -= total_size + with self._capacity_condition: + self._capacity_condition.notify_all() + + def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry]): + """Generator function that registers mutations with flow control. As mutations + are accepted into the flow control, they are yielded back to the caller, + to be sent in a batch. If the flow control is at capacity, the generator + will block until there is capacity available. + + Args: + mutations: list mutations to break up into batches + Yields: + list[RowMutationEntry]: + list of mutations that have reserved space in the flow control. + Each batch contains at least one mutation.""" + if not isinstance(mutations, list): + mutations = [mutations] + start_idx = 0 + end_idx = 0 + while end_idx < len(mutations): + start_idx = end_idx + batch_mutation_count = 0 + with self._capacity_condition: + while end_idx < len(mutations): + next_entry = mutations[end_idx] + next_size = next_entry.size() + next_count = len(next_entry.mutations) + if ( + self._has_capacity(next_count, next_size) + and batch_mutation_count + next_count + <= _MUTATE_ROWS_REQUEST_MUTATION_LIMIT + ): + end_idx += 1 + batch_mutation_count += next_count + self._in_flight_mutation_bytes += next_size + self._in_flight_mutation_count += next_count + elif start_idx != end_idx: + break + else: + self._capacity_condition.wait_for( + lambda: self._has_capacity(next_count, next_size) + ) + yield mutations[start_idx:end_idx] + + +@CrossSync._Sync_Impl.add_mapping_decorator("MutationsBatcher") +class MutationsBatcher: + """ + Allows users to send batches using context manager API: + + Runs mutate_row, mutate_rows, and check_and_mutate_row internally, combining + to use as few network requests as required + + Will automatically flush the batcher: + - every flush_interval seconds + - after queue size reaches flush_limit_mutation_count + - after queue reaches flush_limit_bytes + - when batcher is closed or destroyed + + Args: + table: Table to preform rpc calls + flush_interval: Automatically flush every flush_interval seconds. + If None, no time-based flushing is performed. + flush_limit_mutation_count: Flush immediately after flush_limit_mutation_count + mutations are added across all entries. If None, this limit is ignored. + flush_limit_bytes: Flush immediately after flush_limit_bytes bytes are added. + flow_control_max_mutation_count: Maximum number of inflight mutations. + flow_control_max_bytes: Maximum number of inflight bytes. + batch_operation_timeout: timeout for each mutate_rows operation, in seconds. + If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_operation_timeout. + batch_attempt_timeout: timeout for each individual request, in seconds. + If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_attempt_timeout. + If None, defaults to batch_operation_timeout. + batch_retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_mutate_rows_retryable_errors. 
+ """ + + def __init__( + self, + table: Table, + *, + flush_interval: float | None = 5, + flush_limit_mutation_count: int | None = 1000, + flush_limit_bytes: int = 20 * _MB_SIZE, + flow_control_max_mutation_count: int = 100000, + flow_control_max_bytes: int = 100 * _MB_SIZE, + batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + ): + self._operation_timeout, self._attempt_timeout = _get_timeouts( + batch_operation_timeout, batch_attempt_timeout, table + ) + self._retryable_errors: list[type[Exception]] = _get_retryable_errors( + batch_retryable_errors, table + ) + self._closed = CrossSync._Sync_Impl.Event() + self._table = table + self._staged_entries: list[RowMutationEntry] = [] + self._staged_count, self._staged_bytes = (0, 0) + self._flow_control = CrossSync._Sync_Impl._FlowControl( + flow_control_max_mutation_count, flow_control_max_bytes + ) + self._flush_limit_bytes = flush_limit_bytes + self._flush_limit_count = ( + flush_limit_mutation_count + if flush_limit_mutation_count is not None + else float("inf") + ) + self._sync_executor = ( + concurrent.futures.ThreadPoolExecutor(max_workers=8) + if not CrossSync._Sync_Impl.is_async + else None + ) + self._flush_timer = CrossSync._Sync_Impl.create_task( + self._timer_routine, flush_interval, sync_executor=self._sync_executor + ) + self._flush_jobs: set[CrossSync._Sync_Impl.Future[None]] = set() + self._entries_processed_since_last_raise: int = 0 + self._exceptions_since_last_raise: int = 0 + self._exception_list_limit: int = 10 + self._oldest_exceptions: list[Exception] = [] + self._newest_exceptions: deque[Exception] = deque( + maxlen=self._exception_list_limit + ) + atexit.register(self._on_exit) + + def _timer_routine(self, interval: float | None) -> None: + """Set up a background task to flush the batcher every interval seconds + + If interval is None, an empty future is returned + + Args: + flush_interval: Automatically flush every flush_interval seconds. + If None, no time-based flushing is performed.""" + if not interval or interval <= 0: + return None + while not self._closed.is_set(): + CrossSync._Sync_Impl.event_wait( + self._closed, timeout=interval, async_break_early=False + ) + if not self._closed.is_set() and self._staged_entries: + self._schedule_flush() + + def append(self, mutation_entry: RowMutationEntry): + """Add a new set of mutations to the internal queue + + Args: + mutation_entry: new entry to add to flush queue + Raises: + RuntimeError: if batcher is closed + ValueError: if an invalid mutation type is added""" + if self._closed.is_set(): + raise RuntimeError("Cannot append to closed MutationsBatcher") + if isinstance(mutation_entry, Mutation): + raise ValueError( + f"invalid mutation type: {type(mutation_entry).__name__}. 
Only RowMutationEntry objects are supported by batcher" + ) + self._staged_entries.append(mutation_entry) + self._staged_count += len(mutation_entry.mutations) + self._staged_bytes += mutation_entry.size() + if ( + self._staged_count >= self._flush_limit_count + or self._staged_bytes >= self._flush_limit_bytes + ): + self._schedule_flush() + CrossSync._Sync_Impl.yield_to_event_loop() + + def _schedule_flush(self) -> CrossSync._Sync_Impl.Future[None] | None: + """Update the flush task to include the latest staged entries + + Returns: + Future[None] | None: + future representing the background task, if started""" + if self._staged_entries: + entries, self._staged_entries = (self._staged_entries, []) + self._staged_count, self._staged_bytes = (0, 0) + new_task = CrossSync._Sync_Impl.create_task( + self._flush_internal, entries, sync_executor=self._sync_executor + ) + if not new_task.done(): + self._flush_jobs.add(new_task) + new_task.add_done_callback(self._flush_jobs.remove) + return new_task + return None + + def _flush_internal(self, new_entries: list[RowMutationEntry]): + """Flushes a set of mutations to the server, and updates internal state + + Args: + new_entries list of RowMutationEntry objects to flush""" + in_process_requests: list[ + CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] + ] = [] + for batch in self._flow_control.add_to_flow(new_entries): + batch_task = CrossSync._Sync_Impl.create_task( + self._execute_mutate_rows, batch, sync_executor=self._sync_executor + ) + in_process_requests.append(batch_task) + found_exceptions = self._wait_for_batch_results(*in_process_requests) + self._entries_processed_since_last_raise += len(new_entries) + self._add_exceptions(found_exceptions) + + def _execute_mutate_rows( + self, batch: list[RowMutationEntry] + ) -> list[FailedMutationEntryError]: + """Helper to execute mutation operation on a batch + + Args: + batch: list of RowMutationEntry objects to send to server + timeout: timeout in seconds. Used as operation_timeout and attempt_timeout. + If not given, will use table defaults + Returns: + list[FailedMutationEntryError]: + list of FailedMutationEntryError objects for mutations that failed. + FailedMutationEntryError objects will not contain index information""" + try: + operation = CrossSync._Sync_Impl._MutateRowsOperation( + self._table.client._gapic_client, + self._table, + batch, + operation_timeout=self._operation_timeout, + attempt_timeout=self._attempt_timeout, + retryable_exceptions=self._retryable_errors, + ) + operation.start() + except MutationsExceptionGroup as e: + for subexc in e.exceptions: + subexc.index = None + return list(e.exceptions) + finally: + self._flow_control.remove_from_flow(batch) + return [] + + def _add_exceptions(self, excs: list[Exception]): + """Add new list of exceptions to internal store. To avoid unbounded memory, + the batcher will store the first and last _exception_list_limit exceptions, + and discard any in between. 
+ + Args: + excs: list of exceptions to add to the internal store""" + self._exceptions_since_last_raise += len(excs) + if excs and len(self._oldest_exceptions) < self._exception_list_limit: + addition_count = self._exception_list_limit - len(self._oldest_exceptions) + self._oldest_exceptions.extend(excs[:addition_count]) + excs = excs[addition_count:] + if excs: + self._newest_exceptions.extend(excs[-self._exception_list_limit :]) + + def _raise_exceptions(self): + """Raise any unreported exceptions from background flush operations + + Raises: + MutationsExceptionGroup: exception group with all unreported exceptions""" + if self._oldest_exceptions or self._newest_exceptions: + oldest, self._oldest_exceptions = (self._oldest_exceptions, []) + newest = list(self._newest_exceptions) + self._newest_exceptions.clear() + entry_count, self._entries_processed_since_last_raise = ( + self._entries_processed_since_last_raise, + 0, + ) + exc_count, self._exceptions_since_last_raise = ( + self._exceptions_since_last_raise, + 0, + ) + raise MutationsExceptionGroup.from_truncated_lists( + first_list=oldest, + last_list=newest, + total_excs=exc_count, + entry_count=entry_count, + ) + + def __enter__(self): + """Allow use of context manager API""" + return self + + def __exit__(self, exc_type, exc, tb): + """Allow use of context manager API. + + Flushes the batcher and cleans up resources.""" + self.close() + + @property + def closed(self) -> bool: + """Returns: + - True if the batcher is closed, False otherwise""" + return self._closed.is_set() + + def close(self): + """Flush queue and clean up resources""" + self._closed.set() + self._flush_timer.cancel() + self._schedule_flush() + CrossSync._Sync_Impl.wait([*self._flush_jobs, self._flush_timer]) + if self._sync_executor: + with self._sync_executor: + self._sync_executor.shutdown(wait=True) + atexit.unregister(self._on_exit) + self._raise_exceptions() + + def _on_exit(self): + """Called when program is exited. Raises warning if unflushed mutations remain""" + if not self._closed.is_set() and self._staged_entries: + warnings.warn( + f"MutationsBatcher for table {self._table.table_name} was not closed. {len(self._staged_entries)} Unflushed mutations will not be sent to the server." + ) + + @staticmethod + def _wait_for_batch_results( + *tasks: CrossSync._Sync_Impl.Future[list[FailedMutationEntryError]] + | CrossSync._Sync_Impl.Future[None], + ) -> list[Exception]: + """Takes in a list of futures representing _execute_mutate_rows tasks, + waits for them to complete, and returns a list of errors encountered. + + Args: + *tasks: futures representing _execute_mutate_rows or _flush_internal tasks + Returns: + list[Exception]: + list of Exceptions encountered by any of the tasks. Errors are expected + to be FailedMutationEntryError, representing a failed mutation operation. + If a task fails with a different exception, it will be included in the + output list. Successful tasks will not be represented in the output list. 
+ """ + if not tasks: + return [] + exceptions: list[Exception] = [] + for task in tasks: + if CrossSync._Sync_Impl.is_async: + task + try: + exc_list = task.result() + if exc_list: + for exc in exc_list: + exc.index = None + exceptions.extend(exc_list) + except Exception as e: + exceptions.append(e) + return exceptions diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py new file mode 100644 index 000000000..ac586ea47 --- /dev/null +++ b/tests/system/data/test_system.py @@ -0,0 +1,808 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by CrossSync. Do not edit manually. +import pytest +import uuid +import os +from google.api_core import retry +from google.api_core.exceptions import ClientError +from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE +from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from . import TEST_FAMILY, TEST_FAMILY_2 + + +@CrossSync._Sync_Impl.add_mapping_decorator("TempRowBuilder") +class TempRowBuilder: + """ + Used to add rows to a table for testing purposes. + """ + + def __init__(self, table): + self.rows = [] + self.table = table + + def add_row( + self, row_key, *, family=TEST_FAMILY, qualifier=b"q", value=b"test-value" + ): + if isinstance(value, str): + value = value.encode("utf-8") + elif isinstance(value, int): + value = value.to_bytes(8, byteorder="big", signed=True) + request = { + "table_name": self.table.table_name, + "row_key": row_key, + "mutations": [ + { + "set_cell": { + "family_name": family, + "column_qualifier": qualifier, + "value": value, + } + } + ], + } + self.table.client._gapic_client.mutate_row(request) + self.rows.append(row_key) + + def delete_rows(self): + if self.rows: + request = { + "table_name": self.table.table_name, + "entries": [ + {"row_key": row, "mutations": [{"delete_from_row": {}}]} + for row in self.rows + ], + } + self.table.client._gapic_client.mutate_rows(request) + + +class TestSystem: + @pytest.fixture(scope="session") + def client(self): + project = os.getenv("GOOGLE_CLOUD_PROJECT") or None + with CrossSync._Sync_Impl.DataClient(project=project, pool_size=4) as client: + yield client + + @pytest.fixture(scope="session") + def table(self, client, table_id, instance_id): + with client.get_table(instance_id, table_id) as table: + yield table + + @pytest.fixture(scope="session") + def column_family_config(self): + """specify column families to create when creating a new test table""" + from google.cloud.bigtable_admin_v2 import types + + return {TEST_FAMILY: types.ColumnFamily(), TEST_FAMILY_2: types.ColumnFamily()} + + @pytest.fixture(scope="session") + def init_table_id(self): + """The table_id to use when creating a new test table""" + return f"test-table-{uuid.uuid4().hex}" + + @pytest.fixture(scope="session") + def cluster_config(self, project_id): + """Configuration for the clusters to use when creating a new 
instance""" + from google.cloud.bigtable_admin_v2 import types + + cluster = { + "test-cluster": types.Cluster( + location=f"projects/{project_id}/locations/us-central1-b", serve_nodes=1 + ) + } + return cluster + + @pytest.mark.usefixtures("table") + def _retrieve_cell_value(self, table, row_key): + """Helper to read an individual row""" + from google.cloud.bigtable.data import ReadRowsQuery + + row_list = table.read_rows(ReadRowsQuery(row_keys=row_key)) + assert len(row_list) == 1 + row = row_list[0] + cell = row.cells[0] + return cell.value + + def _create_row_and_mutation( + self, table, temp_rows, *, start_value=b"start", new_value=b"new_value" + ): + """Helper to create a new row, and a sample set_cell mutation to change its value""" + from google.cloud.bigtable.data.mutations import SetCell + + row_key = uuid.uuid4().hex.encode() + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row( + row_key, family=family, qualifier=qualifier, value=start_value + ) + assert self._retrieve_cell_value(table, row_key) == start_value + mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) + return (row_key, mutation) + + @pytest.fixture(scope="function") + def temp_rows(self, table): + builder = CrossSync._Sync_Impl.TempRowBuilder(table) + yield builder + builder.delete_rows() + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=10 + ) + def test_ping_and_warm_gapic(self, client, table): + """Simple ping rpc test + This test ensures channels are able to authenticate with backend""" + request = {"name": table.instance_name} + client._gapic_client.ping_and_warm(request) + + @pytest.mark.usefixtures("table") + @pytest.mark.usefixtures("client") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_ping_and_warm(self, client, table): + """Test ping and warm from handwritten client""" + try: + channel = client.transport._grpc_channel.pool[0] + except Exception: + channel = client.transport._grpc_channel + results = client._ping_and_warm_instances(channel) + assert len(results) == 1 + assert results[0] is None + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutation_set_cell(self, table, temp_rows): + """Ensure cells can be set properly""" + row_key = b"bulk_mutate" + new_value = uuid.uuid4().hex.encode() + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + table.mutate_row(row_key, mutation) + assert self._retrieve_cell_value(table, row_key) == new_value + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), reason="emulator doesn't use splits" + ) + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_sample_row_keys(self, client, table, temp_rows, column_split_config): + """Sample keys should return a single sample in small test tables""" + temp_rows.add_row(b"row_key_1") + temp_rows.add_row(b"row_key_2") + results = table.sample_row_keys() + assert len(results) == len(column_split_config) + 1 + for idx in range(len(column_split_config)): + assert results[idx][0] == column_split_config[idx] + assert isinstance(results[idx][1], int) + assert results[-1][0] == b"" + assert 
isinstance(results[-1][1], int) + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_bulk_mutations_set_cell(self, client, table, temp_rows): + """Ensure cells can be set properly""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + table.bulk_mutate_rows([bulk_mutation]) + assert self._retrieve_cell_value(table, row_key) == new_value + + def test_bulk_mutations_raise_exception(self, client, table): + """If an invalid mutation is passed, an exception should be raised""" + from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + row_key = uuid.uuid4().hex.encode() + mutation = SetCell( + family="nonexistent", qualifier=b"test-qualifier", new_value=b"" + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + with pytest.raises(MutationsExceptionGroup) as exc: + table.bulk_mutate_rows([bulk_mutation]) + assert len(exc.value.exceptions) == 1 + entry_error = exc.value.exceptions[0] + assert isinstance(entry_error, FailedMutationEntryError) + assert entry_error.index == 0 + assert entry_error.entry == bulk_mutation + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_context_manager(self, client, table, temp_rows): + """test batcher with context manager. Should flush on exit""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + row_key2, mutation2 = self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + with table.mutations_batcher() as batcher: + batcher.append(bulk_mutation) + batcher.append(bulk_mutation2) + assert self._retrieve_cell_value(table, row_key) == new_value + assert len(batcher._staged_entries) == 0 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_timer_flush(self, client, table, temp_rows): + """batch should occur after flush_interval seconds""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + flush_interval = 0.1 + with table.mutations_batcher(flush_interval=flush_interval) as batcher: + batcher.append(bulk_mutation) + CrossSync._Sync_Impl.yield_to_event_loop() + assert len(batcher._staged_entries) == 1 + CrossSync._Sync_Impl.sleep(flush_interval + 0.1) + assert len(batcher._staged_entries) == 0 + assert self._retrieve_cell_value(table, row_key) == new_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), 
initial=1, maximum=5 + ) + def test_mutations_batcher_count_flush(self, client, table, temp_rows): + """batch should flush after flush_limit_mutation_count mutations""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + with table.mutations_batcher(flush_limit_mutation_count=2) as batcher: + batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + assert len(batcher._staged_entries) == 1 + batcher.append(bulk_mutation2) + assert len(batcher._flush_jobs) == 1 + for future in list(batcher._flush_jobs): + future + future.result() + assert len(batcher._staged_entries) == 0 + assert len(batcher._flush_jobs) == 0 + assert self._retrieve_cell_value(table, row_key) == new_value + assert self._retrieve_cell_value(table, row_key2) == new_value2 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): + """batch should flush after flush_limit_bytes bytes""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = self._create_row_and_mutation( + table, temp_rows, new_value=new_value2 + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + flush_limit = bulk_mutation.size() + bulk_mutation2.size() - 1 + with table.mutations_batcher(flush_limit_bytes=flush_limit) as batcher: + batcher.append(bulk_mutation) + assert len(batcher._flush_jobs) == 0 + assert len(batcher._staged_entries) == 1 + batcher.append(bulk_mutation2) + assert len(batcher._flush_jobs) == 1 + assert len(batcher._staged_entries) == 0 + for future in list(batcher._flush_jobs): + future + future.result() + assert self._retrieve_cell_value(table, row_key) == new_value + assert self._retrieve_cell_value(table, row_key2) == new_value2 + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_mutations_batcher_no_flush(self, client, table, temp_rows): + """test with no flush requirements met""" + from google.cloud.bigtable.data.mutations import RowMutationEntry + + new_value = uuid.uuid4().hex.encode() + start_value = b"unchanged" + row_key, mutation = self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation = RowMutationEntry(row_key, [mutation]) + row_key2, mutation2 = self._create_row_and_mutation( + table, temp_rows, start_value=start_value, new_value=new_value + ) + bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) + size_limit = bulk_mutation.size() + bulk_mutation2.size() + 1 + with table.mutations_batcher( + flush_limit_bytes=size_limit, flush_limit_mutation_count=3, flush_interval=1 + ) as batcher: + batcher.append(bulk_mutation) + assert len(batcher._staged_entries) == 1 + batcher.append(bulk_mutation2) + assert len(batcher._flush_jobs) == 0 + 
CrossSync._Sync_Impl.yield_to_event_loop() + assert len(batcher._staged_entries) == 2 + assert len(batcher._flush_jobs) == 0 + assert self._retrieve_cell_value(table, row_key) == start_value + assert self._retrieve_cell_value(table, row_key2) == start_value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,increment,expected", + [ + (0, 0, 0), + (0, 1, 1), + (0, -1, -1), + (1, 0, 1), + (0, -100, -100), + (0, 3000, 3000), + (10, 4, 14), + (_MAX_INCREMENT_VALUE, -_MAX_INCREMENT_VALUE, 0), + (_MAX_INCREMENT_VALUE, 2, -_MAX_INCREMENT_VALUE), + (-_MAX_INCREMENT_VALUE, -2, _MAX_INCREMENT_VALUE), + ], + ) + def test_read_modify_write_row_increment( + self, client, table, temp_rows, start, increment, expected + ): + """test read_modify_write_row""" + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + rule = IncrementRule(family, qualifier, increment) + result = table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert int(result[0]) == expected + assert self._retrieve_cell_value(table, row_key) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start,append,expected", + [ + (b"", b"", b""), + ("", "", b""), + (b"abc", b"123", b"abc123"), + (b"abc", "123", b"abc123"), + ("", b"1", b"1"), + (b"abc", "", b"abc"), + (b"hello", b"world", b"helloworld"), + ], + ) + def test_read_modify_write_row_append( + self, client, table, temp_rows, start, append, expected + ): + """test read_modify_write_row""" + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row(row_key, value=start, family=family, qualifier=qualifier) + rule = AppendValueRule(family, qualifier, append) + result = table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert len(result) == 1 + assert result[0].family == family + assert result[0].qualifier == qualifier + assert result[0].value == expected + assert self._retrieve_cell_value(table, row_key) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_read_modify_write_row_chained(self, client, table, temp_rows): + """test read_modify_write_row with multiple rules""" + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + start_amount = 1 + increment_amount = 10 + temp_rows.add_row( + row_key, value=start_amount, family=family, qualifier=qualifier + ) + rule = [ + IncrementRule(family, qualifier, increment_amount), + AppendValueRule(family, qualifier, "hello"), + AppendValueRule(family, qualifier, "world"), + AppendValueRule(family, qualifier, "!"), + ] + result = table.read_modify_write_row(row_key, rule) + assert result.row_key == row_key + assert result[0].family == family + assert result[0].qualifier == qualifier + assert ( + result[0].value + == (start_amount + increment_amount).to_bytes(8, "big", signed=True) + + b"helloworld!" 
+ ) + assert self._retrieve_cell_value(table, row_key) == result[0].value + + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @pytest.mark.parametrize( + "start_val,predicate_range,expected_result", + [(1, (0, 2), True), (-1, (0, 2), False)], + ) + def test_check_and_mutate( + self, client, table, temp_rows, start_val, predicate_range, expected_result + ): + """test that check_and_mutate_row applies the right mutations, and returns the right result""" + from google.cloud.bigtable.data.mutations import SetCell + from google.cloud.bigtable.data.row_filters import ValueRangeFilter + + row_key = b"test-row-key" + family = TEST_FAMILY + qualifier = b"test-qualifier" + temp_rows.add_row(row_key, value=start_val, family=family, qualifier=qualifier) + false_mutation_value = b"false-mutation-value" + false_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=false_mutation_value + ) + true_mutation_value = b"true-mutation-value" + true_mutation = SetCell( + family=TEST_FAMILY, qualifier=qualifier, new_value=true_mutation_value + ) + predicate = ValueRangeFilter(predicate_range[0], predicate_range[1]) + result = table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + assert result == expected_result + expected_value = ( + true_mutation_value if expected_result else false_mutation_value + ) + assert self._retrieve_cell_value(table, row_key) == expected_value + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + def test_check_and_mutate_empty_request(self, client, table): + """check_and_mutate with no true or false mutations should raise an error""" + from google.api_core import exceptions + + with pytest.raises(exceptions.InvalidArgument) as e: + table.check_and_mutate_row( + b"row_key", None, true_case_mutations=None, false_case_mutations=None + ) + assert "No mutations provided" in str(e.value) + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_stream(self, table, temp_rows): + """Ensure that the read_rows_stream method works""" + temp_rows.add_row(b"row_key_1") + temp_rows.add_row(b"row_key_2") + generator = table.read_rows_stream({}) + first_row = generator.__next__() + second_row = generator.__next__() + assert first_row.row_key == b"row_key_1" + assert second_row.row_key == b"row_key_2" + with pytest.raises(CrossSync._Sync_Impl.StopIteration): + generator.__next__() + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows(self, table, temp_rows): + """Ensure that the read_rows method works""" + temp_rows.add_row(b"row_key_1") + temp_rows.add_row(b"row_key_2") + row_list = table.read_rows({}) + assert len(row_list) == 2 + assert row_list[0].row_key == b"row_key_1" + assert row_list[1].row_key == b"row_key_2" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_sharded_simple(self, table, temp_rows): + """Test read rows sharded with two queries""" + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + 
temp_rows.add_row(b"d") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"]) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"]) + row_list = table.read_rows_sharded([query1, query2]) + assert len(row_list) == 4 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"b" + assert row_list[3].row_key == b"d" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_sharded_from_sample(self, table, temp_rows): + """Test end-to-end sharding""" + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.read_rows_query import RowRange + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + table_shard_keys = table.sample_row_keys() + query = ReadRowsQuery(row_ranges=[RowRange(start_key=b"b", end_key=b"z")]) + shard_queries = query.shard(table_shard_keys) + row_list = table.read_rows_sharded(shard_queries) + assert len(row_list) == 3 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + assert row_list[2].row_key == b"d" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_sharded_filters_limits(self, table, temp_rows): + """Test read rows sharded with filters and limits""" + from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + label_filter1 = ApplyLabelFilter("first") + label_filter2 = ApplyLabelFilter("second") + query1 = ReadRowsQuery(row_keys=[b"a", b"c"], limit=1, row_filter=label_filter1) + query2 = ReadRowsQuery(row_keys=[b"b", b"d"], row_filter=label_filter2) + row_list = table.read_rows_sharded([query1, query2]) + assert len(row_list) == 3 + assert row_list[0].row_key == b"a" + assert row_list[1].row_key == b"b" + assert row_list[2].row_key == b"d" + assert row_list[0][0].labels == ["first"] + assert row_list[1][0].labels == ["second"] + assert row_list[2][0].labels == ["second"] + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_range_query(self, table, temp_rows): + """Ensure that the read_rows method works""" + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data import RowRange + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + query = ReadRowsQuery(row_ranges=RowRange(start_key=b"b", end_key=b"d")) + row_list = table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"b" + assert row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_single_key_query(self, table, temp_rows): + """Ensure that the read_rows method works with specified query""" + from google.cloud.bigtable.data import ReadRowsQuery + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + query = ReadRowsQuery(row_keys=[b"a", b"c"]) + row_list = table.read_rows(query) + assert len(row_list) == 2 + assert row_list[0].row_key == b"a" + assert 
row_list[1].row_key == b"c" + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_read_rows_with_filter(self, table, temp_rows): + """ensure filters are applied""" + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + temp_rows.add_row(b"a") + temp_rows.add_row(b"b") + temp_rows.add_row(b"c") + temp_rows.add_row(b"d") + expected_label = "test-label" + row_filter = ApplyLabelFilter(expected_label) + query = ReadRowsQuery(row_filter=row_filter) + row_list = table.read_rows(query) + assert len(row_list) == 4 + for row in row_list: + assert row[0].labels == [expected_label] + + @pytest.mark.usefixtures("table") + def test_read_rows_stream_close(self, table, temp_rows): + """Ensure that the read_rows_stream can be closed""" + from google.cloud.bigtable.data import ReadRowsQuery + + temp_rows.add_row(b"row_key_1") + temp_rows.add_row(b"row_key_2") + query = ReadRowsQuery() + generator = table.read_rows_stream(query) + first_row = generator.__next__() + assert first_row.row_key == b"row_key_1" + generator.close() + with pytest.raises(CrossSync._Sync_Impl.StopIteration): + generator.__next__() + + @pytest.mark.usefixtures("table") + def test_read_row(self, table, temp_rows): + """Test read_row (single row helper)""" + from google.cloud.bigtable.data import Row + + temp_rows.add_row(b"row_key_1", value=b"value") + row = table.read_row(b"row_key_1") + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + def test_read_row_missing(self, table): + """Test read_row when row does not exist""" + from google.api_core import exceptions + + row_key = "row_key_not_exist" + result = table.read_row(row_key) + assert result is None + with pytest.raises(exceptions.InvalidArgument) as e: + table.read_row("") + assert "Row keys must be non-empty" in str(e) + + @pytest.mark.usefixtures("table") + def test_read_row_w_filter(self, table, temp_rows): + """Test read_row (single row helper)""" + from google.cloud.bigtable.data import Row + from google.cloud.bigtable.data.row_filters import ApplyLabelFilter + + temp_rows.add_row(b"row_key_1", value=b"value") + expected_label = "test-label" + label_filter = ApplyLabelFilter(expected_label) + row = table.read_row(b"row_key_1", row_filter=label_filter) + assert isinstance(row, Row) + assert row.row_key == b"row_key_1" + assert row.cells[0].value == b"value" + assert row.cells[0].labels == [expected_label] + + @pytest.mark.skipif( + bool(os.environ.get(BIGTABLE_EMULATOR)), + reason="emulator doesn't raise InvalidArgument", + ) + @pytest.mark.usefixtures("table") + def test_row_exists(self, table, temp_rows): + from google.api_core import exceptions + + "Test row_exists with rows that exist and don't exist" + assert table.row_exists(b"row_key_1") is False + temp_rows.add_row(b"row_key_1") + assert table.row_exists(b"row_key_1") is True + assert table.row_exists("row_key_1") is True + assert table.row_exists(b"row_key_2") is False + assert table.row_exists("row_key_2") is False + assert table.row_exists("3") is False + temp_rows.add_row(b"3") + assert table.row_exists(b"3") is True + with pytest.raises(exceptions.InvalidArgument) as e: + table.row_exists("") + assert "Row keys must be non-empty" 
in str(e) + + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @pytest.mark.parametrize( + "cell_value,filter_input,expect_match", + [ + (b"abc", b"abc", True), + (b"abc", "abc", True), + (b".", ".", True), + (".*", ".*", True), + (".*", b".*", True), + ("a", ".*", False), + (b".*", b".*", True), + ("\\a", "\\a", True), + (b"\xe2\x98\x83", "☃", True), + ("☃", "☃", True), + ("\\C☃", "\\C☃", True), + (1, 1, True), + (2, 1, False), + (68, 68, True), + ("D", 68, False), + (68, "D", False), + (-1, -1, True), + (2852126720, 2852126720, True), + (-1431655766, -1431655766, True), + (-1431655766, -1, False), + ], + ) + def test_literal_value_filter( + self, table, temp_rows, cell_value, filter_input, expect_match + ): + """Literal value filter does complex escaping on re2 strings. + Make sure inputs are properly interpreted by the server""" + from google.cloud.bigtable.data.row_filters import LiteralValueFilter + from google.cloud.bigtable.data import ReadRowsQuery + + f = LiteralValueFilter(filter_input) + temp_rows.add_row(b"row_key_1", value=cell_value) + query = ReadRowsQuery(row_filter=f) + row_list = table.read_rows(query) + assert len(row_list) == bool( + expect_match + ), f"row {type(cell_value)}({cell_value}) not found with {type(filter_input)}({filter_input}) filter" diff --git a/tests/unit/data/_sync/__init__.py b/tests/unit/data/_sync/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/data/_sync/test__mutate_rows.py b/tests/unit/data/_sync/test__mutate_rows.py new file mode 100644 index 000000000..d394ff954 --- /dev/null +++ b/tests/unit/data/_sync/test__mutate_rows.py @@ -0,0 +1,310 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by CrossSync. Do not edit manually. 
+import pytest +from google.cloud.bigtable_v2.types import MutateRowsResponse +from google.rpc import status_pb2 +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import Forbidden +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +try: + from unittest import mock +except ImportError: + import mock + + +class TestMutateRowsOperation: + def _target_class(self): + return CrossSync._Sync_Impl._MutateRowsOperation + + def _make_one(self, *args, **kwargs): + if not args: + kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) + kwargs["table"] = kwargs.pop("table", CrossSync._Sync_Impl.Mock()) + kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) + kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) + kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) + kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) + return self._target_class()(*args, **kwargs) + + def _make_mutation(self, count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def _mock_stream(self, mutation_list, error_dict): + for idx, entry in enumerate(mutation_list): + code = error_dict.get(idx, 0) + yield MutateRowsResponse( + entries=[ + MutateRowsResponse.Entry( + index=idx, status=status_pb2.Status(code=code) + ) + ] + ) + + def _make_mock_gapic(self, mutation_list, error_dict=None): + mock_fn = CrossSync._Sync_Impl.Mock() + if error_dict is None: + error_dict = {} + mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( + mutation_list, error_dict + ) + return mock_fn + + def test_ctor(self): + """test that constructor sets all the attributes correctly""" + from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + from google.api_core.exceptions import DeadlineExceeded + from google.api_core.exceptions import Aborted + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + attempt_timeout = 0.01 + retryable_exceptions = () + instance = self._make_one( + client, + table, + entries, + operation_timeout, + attempt_timeout, + retryable_exceptions, + ) + assert client.mutate_rows.call_count == 0 + instance._gapic_fn() + assert client.mutate_rows.call_count == 1 + inner_kwargs = client.mutate_rows.call_args[1] + assert len(inner_kwargs) == 4 + assert inner_kwargs["table_name"] == table.table_name + assert inner_kwargs["app_profile_id"] == table.app_profile_id + assert inner_kwargs["retry"] is None + metadata = inner_kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert str(table.table_name) in metadata[0][1] + assert str(table.app_profile_id) in metadata[0][1] + entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries] + assert instance.mutations == entries_w_pb + assert next(instance.timeout_generator) == attempt_timeout + assert instance.is_retryable is not None + assert instance.is_retryable(DeadlineExceeded("")) is False + assert instance.is_retryable(Aborted("")) is False + assert instance.is_retryable(_MutateRowsIncomplete("")) is True + assert instance.is_retryable(RuntimeError("")) is False + assert instance.remaining_indices == list(range(len(entries))) + assert instance.errors == {} + + def test_ctor_too_many_entries(self): + """should raise an error if an operation is 
created with more than 100,000 entries""" + from google.cloud.bigtable.data._async._mutate_rows import ( + _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, + ) + + assert _MUTATE_ROWS_REQUEST_MUTATION_LIMIT == 100000 + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] * (_MUTATE_ROWS_REQUEST_MUTATION_LIMIT + 1) + operation_timeout = 0.05 + attempt_timeout = 0.01 + with pytest.raises(ValueError) as e: + self._make_one(client, table, entries, operation_timeout, attempt_timeout) + assert "mutate_rows requests can contain at most 100000 mutations" in str( + e.value + ) + assert "Found 100001" in str(e.value) + + def test_mutate_rows_operation(self): + """Test successful case of mutate_rows_operation""" + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + cls = self._target_class() + with mock.patch( + f"{cls.__module__}.{cls.__name__}._run_attempt", CrossSync._Sync_Impl.Mock() + ) as attempt_mock: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + assert attempt_mock.call_count == 1 + + @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) + def test_mutate_rows_attempt_exception(self, exc_type): + """exceptions raised from attempt should be raised in MutationsExceptionGroup""" + client = CrossSync._Sync_Impl.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + expected_exception = exc_type("test") + client.mutate_rows.side_effect = expected_exception + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance._run_attempt() + except Exception as e: + found_exc = e + assert client.mutate_rows.call_count == 1 + assert type(found_exc) is exc_type + assert found_exc == expected_exception + assert len(instance.errors) == 2 + assert len(instance.remaining_indices) == 0 + + @pytest.mark.parametrize("exc_type", [RuntimeError, ZeroDivisionError, Forbidden]) + def test_mutate_rows_exception(self, exc_type): + """exceptions raised from retryable should be raised in MutationsExceptionGroup""" + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedMutationEntryError + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation(), self._make_mutation()] + operation_timeout = 0.05 + expected_cause = exc_type("abort") + with mock.patch.object( + self._target_class(), "_run_attempt", CrossSync._Sync_Impl.Mock() + ) as attempt_mock: + attempt_mock.side_effect = expected_cause + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + except MutationsExceptionGroup as e: + found_exc = e + assert attempt_mock.call_count == 1 + assert len(found_exc.exceptions) == 2 + assert isinstance(found_exc.exceptions[0], FailedMutationEntryError) + assert isinstance(found_exc.exceptions[1], FailedMutationEntryError) + assert found_exc.exceptions[0].__cause__ == expected_cause + assert found_exc.exceptions[1].__cause__ == expected_cause + + @pytest.mark.parametrize("exc_type", [DeadlineExceeded, RuntimeError]) + def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): + """If an exception fails but eventually passes, it should not raise an exception""" + client = mock.Mock() + table = mock.Mock() + entries = 
[self._make_mutation()] + operation_timeout = 1 + expected_cause = exc_type("retry") + num_retries = 2 + with mock.patch.object( + self._target_class(), "_run_attempt", CrossSync._Sync_Impl.Mock() + ) as attempt_mock: + attempt_mock.side_effect = [expected_cause] * num_retries + [None] + instance = self._make_one( + client, + table, + entries, + operation_timeout, + operation_timeout, + retryable_exceptions=(exc_type,), + ) + instance.start() + assert attempt_mock.call_count == num_retries + 1 + + def test_mutate_rows_incomplete_ignored(self): + """MutateRowsIncomplete exceptions should not be added to error list""" + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup + from google.api_core.exceptions import DeadlineExceeded + + client = mock.Mock() + table = mock.Mock() + entries = [self._make_mutation()] + operation_timeout = 0.05 + with mock.patch.object( + self._target_class(), "_run_attempt", CrossSync._Sync_Impl.Mock() + ) as attempt_mock: + attempt_mock.side_effect = _MutateRowsIncomplete("ignored") + found_exc = None + try: + instance = self._make_one( + client, table, entries, operation_timeout, operation_timeout + ) + instance.start() + except MutationsExceptionGroup as e: + found_exc = e + assert attempt_mock.call_count > 0 + assert len(found_exc.exceptions) == 1 + assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) + + def test_run_attempt_single_entry_success(self): + """Test mutating a single entry""" + mutation = self._make_mutation() + expected_timeout = 1.3 + mock_gapic_fn = self._make_mock_gapic({0: mutation}) + instance = self._make_one( + mutation_entries=[mutation], attempt_timeout=expected_timeout + ) + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + instance._run_attempt() + assert len(instance.remaining_indices) == 0 + assert mock_gapic_fn.call_count == 1 + _, kwargs = mock_gapic_fn.call_args + assert kwargs["timeout"] == expected_timeout + assert kwargs["entries"] == [mutation._to_pb()] + + def test_run_attempt_empty_request(self): + """Calling with no mutations should result in no API calls""" + mock_gapic_fn = self._make_mock_gapic([]) + instance = self._make_one(mutation_entries=[]) + instance._run_attempt() + assert mock_gapic_fn.call_count == 0 + + def test_run_attempt_partial_success_retryable(self): + """Some entries succeed, but one fails. Should report the proper index, and raise incomplete exception""" + from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete + + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() + mutations = [success_mutation, failure_mutation, success_mutation_2] + mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) + instance = self._make_one(mutation_entries=mutations) + instance.is_retryable = lambda x: True + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + with pytest.raises(_MutateRowsIncomplete): + instance._run_attempt() + assert instance.remaining_indices == [1] + assert 0 not in instance.errors + assert len(instance.errors[1]) == 1 + assert instance.errors[1][0].grpc_status_code == 300 + assert 2 not in instance.errors + + def test_run_attempt_partial_success_non_retryable(self): + """Some entries succeed, but one fails. Exception marked as non-retryable. 
Do not raise incomplete error""" + success_mutation = self._make_mutation() + success_mutation_2 = self._make_mutation() + failure_mutation = self._make_mutation() + mutations = [success_mutation, failure_mutation, success_mutation_2] + mock_gapic_fn = self._make_mock_gapic(mutations, error_dict={1: 300}) + instance = self._make_one(mutation_entries=mutations) + instance.is_retryable = lambda x: False + with mock.patch.object(instance, "_gapic_fn", mock_gapic_fn): + instance._run_attempt() + assert instance.remaining_indices == [] + assert 0 not in instance.errors + assert len(instance.errors[1]) == 1 + assert instance.errors[1][0].grpc_status_code == 300 + assert 2 not in instance.errors diff --git a/tests/unit/data/_sync/test__read_rows.py b/tests/unit/data/_sync/test__read_rows.py new file mode 100644 index 000000000..015f96d98 --- /dev/null +++ b/tests/unit/data/_sync/test__read_rows.py @@ -0,0 +1,357 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by CrossSync. Do not edit manually. +import pytest +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +try: + from unittest import mock +except ImportError: + import mock + + +class TestReadRowsOperation: + """ + Tests helper functions in the ReadRowsOperation class + in-depth merging logic in merge_row_response_stream and _read_rows_retryable_attempt + is tested in test_read_rows_acceptance test_client_read_rows, and conformance tests + """ + + @staticmethod + def _get_target_class(): + return CrossSync._Sync_Impl._ReadRowsOperation + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_ctor(self): + from google.cloud.bigtable.data import ReadRowsQuery + + row_limit = 91 + query = ReadRowsQuery(limit=row_limit) + client = mock.Mock() + client.read_rows = mock.Mock() + client.read_rows.return_value = None + table = mock.Mock() + table._client = client + table.table_name = "test_table" + table.app_profile_id = "test_profile" + expected_operation_timeout = 42 + expected_request_timeout = 44 + time_gen_mock = mock.Mock() + subpath = "_async" if CrossSync._Sync_Impl.is_async else "_sync" + with mock.patch( + f"google.cloud.bigtable.data.{subpath}._read_rows._attempt_timeout_generator", + time_gen_mock, + ): + instance = self._make_one( + query, + table, + operation_timeout=expected_operation_timeout, + attempt_timeout=expected_request_timeout, + ) + assert time_gen_mock.call_count == 1 + time_gen_mock.assert_called_once_with( + expected_request_timeout, expected_operation_timeout + ) + assert instance._last_yielded_row_key is None + assert instance._remaining_count == row_limit + assert instance.operation_timeout == expected_operation_timeout + assert client.read_rows.call_count == 0 + assert instance._metadata == [ + ( + "x-goog-request-params", + "table_name=test_table&app_profile_id=test_profile", + ) + ] + assert instance.request.table_name == table.table_name + assert instance.request.app_profile_id == 
table.app_profile_id + assert instance.request.rows_limit == row_limit + + @pytest.mark.parametrize( + "in_keys,last_key,expected", + [ + (["b", "c", "d"], "a", ["b", "c", "d"]), + (["a", "b", "c"], "b", ["c"]), + (["a", "b", "c"], "c", []), + (["a", "b", "c"], "d", []), + (["d", "c", "b", "a"], "b", ["d", "c"]), + ], + ) + @pytest.mark.parametrize("with_range", [True, False]) + def test_revise_request_rowset_keys_with_range( + self, in_keys, last_key, expected, with_range + ): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + from google.cloud.bigtable.data.exceptions import _RowSetComplete + + in_keys = [key.encode("utf-8") for key in in_keys] + expected = [key.encode("utf-8") for key in expected] + last_key = last_key.encode("utf-8") + if with_range: + sample_range = [RowRangePB(start_key_open=last_key)] + else: + sample_range = [] + row_set = RowSetPB(row_keys=in_keys, row_ranges=sample_range) + if not with_range and expected == []: + with pytest.raises(_RowSetComplete): + self._get_target_class()._revise_request_rowset(row_set, last_key) + else: + revised = self._get_target_class()._revise_request_rowset(row_set, last_key) + assert revised.row_keys == expected + assert revised.row_ranges == sample_range + + @pytest.mark.parametrize( + "in_ranges,last_key,expected", + [ + ( + [{"start_key_open": "b", "end_key_closed": "d"}], + "a", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_closed": "b", "end_key_closed": "d"}], + "a", + [{"start_key_closed": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_open": "a", "end_key_closed": "d"}], + "b", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ( + [{"start_key_closed": "a", "end_key_open": "d"}], + "b", + [{"start_key_open": "b", "end_key_open": "d"}], + ), + ( + [{"start_key_closed": "b", "end_key_closed": "d"}], + "b", + [{"start_key_open": "b", "end_key_closed": "d"}], + ), + ([{"start_key_closed": "b", "end_key_closed": "d"}], "d", []), + ([{"start_key_closed": "b", "end_key_open": "d"}], "d", []), + ([{"start_key_closed": "b", "end_key_closed": "d"}], "e", []), + ([{"start_key_closed": "b"}], "z", [{"start_key_open": "z"}]), + ([{"start_key_closed": "b"}], "a", [{"start_key_closed": "b"}]), + ( + [{"end_key_closed": "z"}], + "a", + [{"start_key_open": "a", "end_key_closed": "z"}], + ), + ( + [{"end_key_open": "z"}], + "a", + [{"start_key_open": "a", "end_key_open": "z"}], + ), + ], + ) + @pytest.mark.parametrize("with_key", [True, False]) + def test_revise_request_rowset_ranges( + self, in_ranges, last_key, expected, with_key + ): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + from google.cloud.bigtable.data.exceptions import _RowSetComplete + + next_key = (last_key + "a").encode("utf-8") + last_key = last_key.encode("utf-8") + in_ranges = [ + RowRangePB(**{k: v.encode("utf-8") for k, v in r.items()}) + for r in in_ranges + ] + expected = [ + RowRangePB(**{k: v.encode("utf-8") for k, v in r.items()}) for r in expected + ] + if with_key: + row_keys = [next_key] + else: + row_keys = [] + row_set = RowSetPB(row_ranges=in_ranges, row_keys=row_keys) + if not with_key and expected == []: + with pytest.raises(_RowSetComplete): + self._get_target_class()._revise_request_rowset(row_set, last_key) + else: + revised = self._get_target_class()._revise_request_rowset(row_set, last_key) + assert revised.row_keys == row_keys + assert 
revised.row_ranges == expected + + @pytest.mark.parametrize("last_key", ["a", "b", "c"]) + def test_revise_request_full_table(self, last_key): + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + last_key = last_key.encode("utf-8") + row_set = RowSetPB() + for selected_set in [row_set, None]: + revised = self._get_target_class()._revise_request_rowset( + selected_set, last_key + ) + assert revised.row_keys == [] + assert len(revised.row_ranges) == 1 + assert revised.row_ranges[0] == RowRangePB(start_key_open=last_key) + + def test_revise_to_empty_rowset(self): + """revising to an empty rowset should raise error""" + from google.cloud.bigtable.data.exceptions import _RowSetComplete + from google.cloud.bigtable_v2.types import RowSet as RowSetPB + from google.cloud.bigtable_v2.types import RowRange as RowRangePB + + row_keys = [b"a", b"b", b"c"] + row_range = RowRangePB(end_key_open=b"c") + row_set = RowSetPB(row_keys=row_keys, row_ranges=[row_range]) + with pytest.raises(_RowSetComplete): + self._get_target_class()._revise_request_rowset(row_set, b"d") + + @pytest.mark.parametrize( + "start_limit,emit_num,expected_limit", + [ + (10, 0, 10), + (10, 1, 9), + (10, 10, 0), + (None, 10, None), + (None, 0, None), + (4, 2, 2), + ], + ) + def test_revise_limit(self, start_limit, emit_num, expected_limit): + """revise_limit should revise the request's limit field + - if limit is None (unlimited), it should never be revised + - if start_limit-emit_num == 0, the request should end early + - if the number emitted exceeds the new limit, an exception + should be raised (tested in test_revise_limit_over_limit)""" + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable_v2.types import ReadRowsResponse + + def awaitable_stream(): + def mock_stream(): + for i in range(emit_num): + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + ] + ) + + return mock_stream() + + query = ReadRowsQuery(limit=start_limit) + table = mock.Mock() + table.table_name = "table_name" + table.app_profile_id = "app_profile_id" + instance = self._make_one(query, table, 10, 10) + assert instance._remaining_count == start_limit + for val in instance.chunk_stream(awaitable_stream()): + pass + assert instance._remaining_count == expected_limit + + @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)]) + def test_revise_limit_over_limit(self, start_limit, emit_num): + """Should raise an error if we get in a state where emit_num > start_limit + (unless start_limit is None, which represents unlimited)""" + from google.cloud.bigtable.data import ReadRowsQuery + from google.cloud.bigtable_v2.types import ReadRowsResponse + from google.cloud.bigtable.data.exceptions import InvalidChunk + + def awaitable_stream(): + def mock_stream(): + for i in range(emit_num): + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + ] + ) + + return mock_stream() + + query = ReadRowsQuery(limit=start_limit) + table = mock.Mock() + table.table_name = "table_name" + table.app_profile_id = "app_profile_id" + instance = self._make_one(query, table, 10, 10) + assert instance._remaining_count == start_limit + with pytest.raises(InvalidChunk) as e: + for val in instance.chunk_stream(awaitable_stream()): + pass 
+ assert "emit count exceeds row limit" in str(e.value) + + def test_close(self): + """should be able to close a stream safely with aclose. + Closed generators should raise StopAsyncIteration on next yield""" + + def mock_stream(): + while True: + yield 1 + + with mock.patch.object( + self._get_target_class(), "_read_rows_attempt" + ) as mock_attempt: + instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) + wrapped_gen = mock_stream() + mock_attempt.return_value = wrapped_gen + gen = instance.start_operation() + gen.__next__() + gen.close() + with pytest.raises(CrossSync._Sync_Impl.StopIteration): + gen.__next__() + gen.close() + with pytest.raises(CrossSync._Sync_Impl.StopIteration): + wrapped_gen.__next__() + + def test_retryable_ignore_repeated_rows(self): + """Duplicate rows should cause an invalid chunk error""" + from google.cloud.bigtable.data.exceptions import InvalidChunk + from google.cloud.bigtable_v2.types import ReadRowsResponse + + row_key = b"duplicate" + + def mock_awaitable_stream(): + def mock_stream(): + while True: + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) + ] + ) + yield ReadRowsResponse( + chunks=[ + ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) + ] + ) + + return mock_stream() + + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + stream = self._get_target_class().chunk_stream( + instance, mock_awaitable_stream() + ) + stream.__next__() + with pytest.raises(InvalidChunk) as exc: + stream.__next__() + assert "row keys should be strictly increasing" in str(exc.value) diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py new file mode 100644 index 000000000..b67f77298 --- /dev/null +++ b/tests/unit/data/_sync/test_client.py @@ -0,0 +1,2744 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by CrossSync. Do not edit manually. 
+from __future__ import annotations +import grpc +import asyncio +import re +import pytest +from google.cloud.bigtable.data import mutations +from google.auth.credentials import AnonymousCredentials +from google.cloud.bigtable_v2.types import ReadRowsResponse +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.api_core import exceptions as core_exceptions +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data import TABLE_DEFAULT +from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule +from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +try: + from unittest import mock +except ImportError: + import mock +if CrossSync._Sync_Impl.is_async: + from google.api_core import grpc_helpers_async + + CrossSync._Sync_Impl.add_mapping("grpc_helpers", grpc_helpers_async) +else: + from google.api_core import grpc_helpers + + CrossSync._Sync_Impl.add_mapping("grpc_helpers", grpc_helpers) + + +@CrossSync._Sync_Impl.add_mapping_decorator("TestBigtableDataClient") +class TestBigtableDataClient: + @staticmethod + def _get_target_class(): + return CrossSync._Sync_Impl.DataClient + + @classmethod + def _make_client(cls, *args, use_emulator=True, **kwargs): + import os + + env_mask = {} + if use_emulator: + env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost" + import warnings + + warnings.filterwarnings("ignore", category=RuntimeWarning) + else: + kwargs["credentials"] = kwargs.get("credentials", AnonymousCredentials()) + kwargs["project"] = kwargs.get("project", "project-id") + with mock.patch.dict(os.environ, env_mask): + return cls._get_target_class()(*args, **kwargs) + + def test_ctor(self): + expected_project = "project-id" + expected_pool_size = 11 + expected_credentials = AnonymousCredentials() + client = self._make_client( + project="project-id", + pool_size=expected_pool_size, + credentials=expected_credentials, + use_emulator=False, + ) + CrossSync._Sync_Impl.yield_to_event_loop() + assert client.project == expected_project + assert len(client.transport._grpc_channel._pool) == expected_pool_size + assert not client._active_instances + assert len(client._channel_refresh_tasks) == expected_pool_size + assert client.transport._credentials == expected_credentials + client.close() + + def test_ctor_super_inits(self): + from google.cloud.client import ClientWithProject + from google.api_core import client_options as client_options_lib + from google.cloud.bigtable import __version__ as bigtable_version + + project = "project-id" + pool_size = 11 + credentials = AnonymousCredentials() + client_options = {"api_endpoint": "foo.bar:1234"} + options_parsed = client_options_lib.from_dict(client_options) + asyncio_portion = "-async" if CrossSync._Sync_Impl.is_async else "" + transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" + with mock.patch.object( + CrossSync._Sync_Impl.GapicClient, "__init__" + ) as bigtable_client_init: + bigtable_client_init.return_value = None + with mock.patch.object( + ClientWithProject, "__init__" + ) as client_project_init: + client_project_init.return_value = None + try: + self._make_client( + project=project, + pool_size=pool_size, + credentials=credentials, + client_options=options_parsed, + use_emulator=False, + ) + except AttributeError: + pass + assert bigtable_client_init.call_count == 1 + kwargs = 
bigtable_client_init.call_args[1] + assert kwargs["transport"] == transport_str + assert kwargs["credentials"] == credentials + assert kwargs["client_options"] == options_parsed + assert client_project_init.call_count == 1 + kwargs = client_project_init.call_args[1] + assert kwargs["project"] == project + assert kwargs["credentials"] == credentials + assert kwargs["client_options"] == options_parsed + + def test_ctor_dict_options(self): + from google.api_core.client_options import ClientOptions + + client_options = {"api_endpoint": "foo.bar:1234"} + with mock.patch.object( + CrossSync._Sync_Impl.GapicClient, "__init__" + ) as bigtable_client_init: + try: + self._make_client(client_options=client_options) + except TypeError: + pass + bigtable_client_init.assert_called_once() + kwargs = bigtable_client_init.call_args[1] + called_options = kwargs["client_options"] + assert called_options.api_endpoint == "foo.bar:1234" + assert isinstance(called_options, ClientOptions) + with mock.patch.object( + self._get_target_class(), "_start_background_channel_refresh" + ) as start_background_refresh: + client = self._make_client( + client_options=client_options, use_emulator=False + ) + start_background_refresh.assert_called_once() + client.close() + + def test_veneer_grpc_headers(self): + client_component = "data-async" if CrossSync._Sync_Impl.is_async else "data" + VENEER_HEADER_REGEX = re.compile( + "gapic\\/[0-9]+\\.[\\w.-]+ gax\\/[0-9]+\\.[\\w.-]+ gccl\\/[0-9]+\\.[\\w.-]+-" + + client_component + + " gl-python\\/[0-9]+\\.[\\w.-]+ grpc\\/[0-9]+\\.[\\w.-]+" + ) + if CrossSync._Sync_Impl.is_async: + patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") + else: + patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") + with patch as gapic_mock: + client = self._make_client(project="project-id") + wrapped_call_list = gapic_mock.call_args_list + assert len(wrapped_call_list) > 0 + for call in wrapped_call_list: + client_info = call.kwargs["client_info"] + assert client_info is not None, f"{call} has no client_info" + wrapped_user_agent_sorted = " ".join( + sorted(client_info.to_user_agent().split(" ")) + ) + assert VENEER_HEADER_REGEX.match( + wrapped_user_agent_sorted + ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" + client.close() + + def test_channel_pool_creation(self): + pool_size = 14 + with mock.patch.object( + CrossSync._Sync_Impl.grpc_helpers, + "create_channel", + CrossSync._Sync_Impl.Mock(), + ) as create_channel: + client = self._make_client(project="project-id", pool_size=pool_size) + assert create_channel.call_count == pool_size + client.close() + client = self._make_client(project="project-id", pool_size=pool_size) + pool_list = list(client.transport._grpc_channel._pool) + pool_set = set(client.transport._grpc_channel._pool) + assert len(pool_list) == len(pool_set) + client.close() + + def test_channel_pool_rotation(self): + pool_size = 7 + with mock.patch.object( + CrossSync._Sync_Impl.PooledChannel, "next_channel" + ) as next_channel: + client = self._make_client(project="project-id", pool_size=pool_size) + assert len(client.transport._grpc_channel._pool) == pool_size + next_channel.reset_mock() + with mock.patch.object( + type(client.transport._grpc_channel._pool[0]), "unary_unary" + ) as unary_unary: + channel_next = None + for i in range(pool_size): + channel_last = channel_next + channel_next = client.transport.grpc_channel._pool[i] + assert channel_last != channel_next + next_channel.return_value = channel_next + 
client.transport.ping_and_warm() + assert next_channel.call_count == i + 1 + unary_unary.assert_called_once() + unary_unary.reset_mock() + client.close() + + def test_channel_pool_replace(self): + import time + + sleep_module = asyncio if CrossSync._Sync_Impl.is_async else time + with mock.patch.object(sleep_module, "sleep"): + pool_size = 7 + client = self._make_client(project="project-id", pool_size=pool_size) + for replace_idx in range(pool_size): + start_pool = [ + channel for channel in client.transport._grpc_channel._pool + ] + grace_period = 9 + with mock.patch.object( + type(client.transport._grpc_channel._pool[-1]), "close" + ) as close: + new_channel = client.transport.create_channel() + client.transport.replace_channel( + replace_idx, grace=grace_period, new_channel=new_channel + ) + close.assert_called_once() + if CrossSync._Sync_Impl.is_async: + close.assert_called_once_with(grace=grace_period) + close.assert_awaited_once() + assert client.transport._grpc_channel._pool[replace_idx] == new_channel + for i in range(pool_size): + if i != replace_idx: + assert client.transport._grpc_channel._pool[i] == start_pool[i] + else: + assert client.transport._grpc_channel._pool[i] != start_pool[i] + client.close() + + def test__start_background_channel_refresh_tasks_exist(self): + client = self._make_client(project="project-id", use_emulator=False) + assert len(client._channel_refresh_tasks) > 0 + with mock.patch.object(asyncio, "create_task") as create_task: + client._start_background_channel_refresh() + create_task.assert_not_called() + client.close() + + @pytest.mark.parametrize("pool_size", [1, 3, 7]) + def test__start_background_channel_refresh(self, pool_size): + import concurrent.futures + + with mock.patch.object( + self._get_target_class(), + "_ping_and_warm_instances", + CrossSync._Sync_Impl.Mock(), + ) as ping_and_warm: + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + client._start_background_channel_refresh() + assert len(client._channel_refresh_tasks) == pool_size + for task in client._channel_refresh_tasks: + if CrossSync._Sync_Impl.is_async: + assert isinstance(task, asyncio.Task) + else: + assert isinstance(task, concurrent.futures.Future) + if CrossSync._Sync_Impl.is_async: + asyncio.sleep(0.1) + assert ping_and_warm.call_count == pool_size + for channel in client.transport._grpc_channel._pool: + ping_and_warm.assert_any_call(channel) + client.close() + + def test__ping_and_warm_instances(self): + """test ping and warm with mocked asyncio.gather""" + client_mock = mock.Mock() + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) + with mock.patch.object( + CrossSync._Sync_Impl, "gather_partials", CrossSync._Sync_Impl.Mock() + ) as gather: + gather.side_effect = lambda partials, **kwargs: [None for _ in partials] + channel = mock.Mock() + client_mock._active_instances = [] + result = self._get_target_class()._ping_and_warm_instances( + client_mock, channel + ) + assert len(result) == 0 + assert gather.call_args.kwargs["return_exceptions"] is True + assert gather.call_args.kwargs["sync_executor"] == client_mock._executor + client_mock._active_instances = [ + (mock.Mock(), mock.Mock(), mock.Mock()) + ] * 4 + gather.reset_mock() + channel.reset_mock() + result = self._get_target_class()._ping_and_warm_instances( + client_mock, channel + ) + assert len(result) == 4 + gather.assert_called_once() + partial_list = gather.call_args.args[0] + assert 
len(partial_list) == 4 + if CrossSync._Sync_Impl.is_async: + gather.assert_awaited_once() + grpc_call_args = channel.unary_unary().call_args_list + for idx, (_, kwargs) in enumerate(grpc_call_args): + ( + expected_instance, + expected_table, + expected_app_profile, + ) = client_mock._active_instances[idx] + request = kwargs["request"] + assert request["name"] == expected_instance + assert request["app_profile_id"] == expected_app_profile + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert ( + metadata[0][1] + == f"name={expected_instance}&app_profile_id={expected_app_profile}" + ) + + def test__ping_and_warm_single_instance(self): + """should be able to call ping and warm with single instance""" + client_mock = mock.Mock() + client_mock._execute_ping_and_warms = ( + lambda *args: self._get_target_class()._execute_ping_and_warms( + client_mock, *args + ) + ) + with mock.patch.object( + CrossSync._Sync_Impl, "gather_partials", CrossSync._Sync_Impl.Mock() + ) as gather: + gather.side_effect = lambda *args, **kwargs: [fn() for fn in args[0]] + channel = mock.Mock() + client_mock._active_instances = [mock.Mock()] * 100 + test_key = ("test-instance", "test-table", "test-app-profile") + result = self._get_target_class()._ping_and_warm_instances( + client_mock, channel, test_key + ) + assert len(result) == 1 + grpc_call_args = channel.unary_unary().call_args_list + assert len(grpc_call_args) == 1 + kwargs = grpc_call_args[0][1] + request = kwargs["request"] + assert request["name"] == "test-instance" + assert request["app_profile_id"] == "test-app-profile" + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + assert ( + metadata[0][1] == "name=test-instance&app_profile_id=test-app-profile" + ) + + @pytest.mark.parametrize( + "refresh_interval, wait_time, expected_sleep", + [(0, 0, 0), (0, 1, 0), (10, 0, 10), (10, 5, 5), (10, 10, 0), (10, 15, 0)], + ) + def test__manage_channel_first_sleep( + self, refresh_interval, wait_time, expected_sleep + ): + import time + + with mock.patch.object(time, "monotonic") as monotonic: + monotonic.return_value = 0 + with mock.patch.object(CrossSync._Sync_Impl, "event_wait") as sleep: + sleep.side_effect = asyncio.CancelledError + try: + client = self._make_client(project="project-id") + client._channel_init_time = -wait_time + client._manage_channel(0, refresh_interval, refresh_interval) + except asyncio.CancelledError: + pass + sleep.assert_called_once() + call_time = sleep.call_args[0][1] + assert ( + abs(call_time - expected_sleep) < 0.1 + ), f"refresh_interval: {refresh_interval}, wait_time: {wait_time}, expected_sleep: {expected_sleep}" + client.close() + + def test__manage_channel_ping_and_warm(self): + """_manage channel should call ping and warm internally""" + import time + import threading + + client_mock = mock.Mock() + client_mock._is_closed.is_set.return_value = False + client_mock._channel_init_time = time.monotonic() + channel_list = [mock.Mock(), mock.Mock()] + client_mock.transport.channels = channel_list + new_channel = mock.Mock() + client_mock.transport.grpc_channel._create_channel.return_value = new_channel + sleep_tuple = ( + (asyncio, "sleep") + if CrossSync._Sync_Impl.is_async + else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple): + client_mock.transport.replace_channel.side_effect = asyncio.CancelledError + ping_and_warm = ( + client_mock._ping_and_warm_instances + ) = 
CrossSync._Sync_Impl.Mock() + try: + channel_idx = 1 + self._get_target_class()._manage_channel(client_mock, channel_idx, 10) + except asyncio.CancelledError: + pass + assert ping_and_warm.call_count == 2 + assert client_mock.transport.replace_channel.call_count == 1 + old_channel = channel_list[channel_idx] + assert old_channel != new_channel + called_with = [call[0][0] for call in ping_and_warm.call_args_list] + assert old_channel in called_with + assert new_channel in called_with + ping_and_warm.reset_mock() + try: + self._get_target_class()._manage_channel(client_mock, 0, 0, 0) + except asyncio.CancelledError: + pass + ping_and_warm.assert_called_once_with(new_channel) + + @pytest.mark.parametrize( + "refresh_interval, num_cycles, expected_sleep", + [(None, 1, 60 * 35), (10, 10, 100), (10, 1, 10)], + ) + def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sleep): + import time + import random + import threading + + channel_idx = 1 + with mock.patch.object(random, "uniform") as uniform: + uniform.side_effect = lambda min_, max_: min_ + with mock.patch.object(time, "time") as time_mock: + time_mock.return_value = 0 + sleep_tuple = ( + (asyncio, "sleep") + if CrossSync._Sync_Impl.is_async + else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: + sleep.side_effect = [None for i in range(num_cycles - 1)] + [ + asyncio.CancelledError + ] + client = self._make_client(project="project-id") + with mock.patch.object(client.transport, "replace_channel"): + try: + if refresh_interval is not None: + client._manage_channel( + channel_idx, refresh_interval, refresh_interval + ) + else: + client._manage_channel(channel_idx) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + if CrossSync._Sync_Impl.is_async: + total_sleep = sum([call[0][0] for call in sleep.call_args_list]) + else: + total_sleep = sum( + [call[1]["timeout"] for call in sleep.call_args_list] + ) + assert ( + abs(total_sleep - expected_sleep) < 0.1 + ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" + client.close() + + def test__manage_channel_random(self): + import random + import threading + + sleep_tuple = ( + (asyncio, "sleep") + if CrossSync._Sync_Impl.is_async + else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: + with mock.patch.object(random, "uniform") as uniform: + uniform.return_value = 0 + try: + uniform.side_effect = asyncio.CancelledError + client = self._make_client(project="project-id", pool_size=1) + except asyncio.CancelledError: + uniform.side_effect = None + uniform.reset_mock() + sleep.reset_mock() + min_val = 200 + max_val = 205 + uniform.side_effect = lambda min_, max_: min_ + sleep.side_effect = [None, None, asyncio.CancelledError] + try: + with mock.patch.object(client.transport, "replace_channel"): + client._manage_channel(0, min_val, max_val) + except asyncio.CancelledError: + pass + assert uniform.call_count == 3 + uniform_args = [call[0] for call in uniform.call_args_list] + for found_min, found_max in uniform_args: + assert found_min == min_val + assert found_max == max_val + + @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) + def test__manage_channel_refresh(self, num_cycles): + import threading + + expected_grace = 9 + expected_refresh = 0.5 + channel_idx = 1 + grpc_lib = grpc.aio if CrossSync._Sync_Impl.is_async else grpc + new_channel = grpc_lib.insecure_channel("localhost:8080") + with mock.patch.object( + 
CrossSync._Sync_Impl.PooledTransport, "replace_channel" + ) as replace_channel: + sleep_tuple = ( + (asyncio, "sleep") + if CrossSync._Sync_Impl.is_async + else (threading.Event, "wait") + ) + with mock.patch.object(*sleep_tuple) as sleep: + sleep.side_effect = [None for i in range(num_cycles)] + [ + asyncio.CancelledError + ] + with mock.patch.object( + CrossSync._Sync_Impl.grpc_helpers, "create_channel" + ) as create_channel: + create_channel.return_value = new_channel + with mock.patch.object( + self._get_target_class(), "_start_background_channel_refresh" + ): + client = self._make_client( + project="project-id", use_emulator=False + ) + create_channel.reset_mock() + try: + client._manage_channel( + channel_idx, + refresh_interval_min=expected_refresh, + refresh_interval_max=expected_refresh, + grace_period=expected_grace, + ) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + 1 + assert create_channel.call_count == num_cycles + assert replace_channel.call_count == num_cycles + for call in replace_channel.call_args_list: + args, kwargs = call + assert args[0] == channel_idx + assert kwargs["grace"] == expected_grace + assert kwargs["new_channel"] == new_channel + client.close() + + def test__register_instance(self): + """test instance registration""" + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: f"prefix/{b}" + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_tasks = [] + client_mock._start_background_channel_refresh.side_effect = ( + lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + ) + mock_channels = [mock.Mock() for i in range(5)] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() + table_mock = mock.Mock() + self._get_target_class()._register_instance( + client_mock, "instance-1", table_mock + ) + assert client_mock._start_background_channel_refresh.call_count == 1 + expected_key = ( + "prefix/instance-1", + table_mock.table_name, + table_mock.app_profile_id, + ) + assert len(active_instances) == 1 + assert expected_key == tuple(list(active_instances)[0]) + assert len(instance_owners) == 1 + assert expected_key == tuple(list(instance_owners)[0]) + assert client_mock._channel_refresh_tasks + table_mock2 = mock.Mock() + self._get_target_class()._register_instance( + client_mock, "instance-2", table_mock2 + ) + assert client_mock._start_background_channel_refresh.call_count == 1 + assert client_mock._ping_and_warm_instances.call_count == len(mock_channels) + for channel in mock_channels: + assert channel in [ + call[0][0] + for call in client_mock._ping_and_warm_instances.call_args_list + ] + assert len(active_instances) == 2 + assert len(instance_owners) == 2 + expected_key2 = ( + "prefix/instance-2", + table_mock2.table_name, + table_mock2.app_profile_id, + ) + assert any( + [ + expected_key2 == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) + assert any( + [ + expected_key2 == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) + + @pytest.mark.parametrize( + "insert_instances,expected_active,expected_owner_keys", + [ + ([("i", "t", None)], [("i", "t", None)], [("i", "t", None)]), + ([("i", "t", "p")], [("i", "t", "p")], [("i", "t", "p")]), + ([("1", "t", "p"), ("1", "t", "p")], [("1", "t", "p")], [("1", "t", "p")]), + ( + [("1", "t", 
"p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + [("1", "t", "p"), ("2", "t", "p")], + ), + ], + ) + def test__register_instance_state( + self, insert_instances, expected_active, expected_owner_keys + ): + """test that active_instances and instance_owners are updated as expected""" + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: b + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_tasks = [] + client_mock._start_background_channel_refresh.side_effect = ( + lambda: client_mock._channel_refresh_tasks.append(mock.Mock) + ) + mock_channels = [mock.Mock() for i in range(5)] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() + table_mock = mock.Mock() + for instance, table, profile in insert_instances: + table_mock.table_name = table + table_mock.app_profile_id = profile + self._get_target_class()._register_instance( + client_mock, instance, table_mock + ) + assert len(active_instances) == len(expected_active) + assert len(instance_owners) == len(expected_owner_keys) + for expected in expected_active: + assert any( + [ + expected == tuple(list(active_instances)[i]) + for i in range(len(active_instances)) + ] + ) + for expected in expected_owner_keys: + assert any( + [ + expected == tuple(list(instance_owners)[i]) + for i in range(len(instance_owners)) + ] + ) + + def test__remove_instance_registration(self): + client = self._make_client(project="project-id") + table = mock.Mock() + client._register_instance("instance-1", table) + client._register_instance("instance-2", table) + assert len(client._active_instances) == 2 + assert len(client._instance_owners.keys()) == 2 + instance_1_path = client._gapic_client.instance_path( + client.project, "instance-1" + ) + instance_1_key = (instance_1_path, table.table_name, table.app_profile_id) + instance_2_path = client._gapic_client.instance_path( + client.project, "instance-2" + ) + instance_2_key = (instance_2_path, table.table_name, table.app_profile_id) + assert len(client._instance_owners[instance_1_key]) == 1 + assert list(client._instance_owners[instance_1_key])[0] == id(table) + assert len(client._instance_owners[instance_2_key]) == 1 + assert list(client._instance_owners[instance_2_key])[0] == id(table) + success = client._remove_instance_registration("instance-1", table) + assert success + assert len(client._active_instances) == 1 + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 1 + assert client._active_instances == {instance_2_key} + success = client._remove_instance_registration("fake-key", table) + assert not success + assert len(client._active_instances) == 1 + client.close() + + def test__multiple_table_registration(self): + """registering with multiple tables with the same key should + add multiple owners to instance_owners, but only keep one copy + of shared key in active_instances""" + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + with self._make_client(project="project-id") as client: + with client.get_table("instance_1", "table_1") as table_1: + instance_1_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) 
== 1 + assert len(client._active_instances) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + with client.get_table("instance_1", "table_1") as table_2: + assert table_2._register_instance_future is not None + if not CrossSync._Sync_Impl.is_async: + table_2._register_instance_future.result() + assert len(client._instance_owners[instance_1_key]) == 2 + assert len(client._active_instances) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + with client.get_table("instance_1", "table_3") as table_3: + assert table_3._register_instance_future is not None + if not CrossSync._Sync_Impl.is_async: + table_3._register_instance_future.result() + instance_3_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_3_key = _WarmedInstanceKey( + instance_3_path, table_3.table_name, table_3.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 2 + assert len(client._instance_owners[instance_3_key]) == 1 + assert len(client._active_instances) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_1_key] + assert id(table_3) in client._instance_owners[instance_3_key] + assert len(client._active_instances) == 1 + assert instance_1_key in client._active_instances + assert id(table_2) not in client._instance_owners[instance_1_key] + assert len(client._active_instances) == 0 + assert instance_1_key not in client._active_instances + assert len(client._instance_owners[instance_1_key]) == 0 + + def test__multiple_instance_registration(self): + """registering with multiple instance keys should update the key + in instance_owners and active_instances""" + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + with self._make_client(project="project-id") as client: + with client.get_table("instance_1", "table_1") as table_1: + assert table_1._register_instance_future is not None + if not CrossSync._Sync_Impl.is_async: + table_1._register_instance_future.result() + with client.get_table("instance_2", "table_2") as table_2: + assert table_2._register_instance_future is not None + if not CrossSync._Sync_Impl.is_async: + table_2._register_instance_future.result() + instance_1_path = client._gapic_client.instance_path( + client.project, "instance_1" + ) + instance_1_key = _WarmedInstanceKey( + instance_1_path, table_1.table_name, table_1.app_profile_id + ) + instance_2_path = client._gapic_client.instance_path( + client.project, "instance_2" + ) + instance_2_key = _WarmedInstanceKey( + instance_2_path, table_2.table_name, table_2.app_profile_id + ) + assert len(client._instance_owners[instance_1_key]) == 1 + assert len(client._instance_owners[instance_2_key]) == 1 + assert len(client._active_instances) == 2 + assert id(table_1) in client._instance_owners[instance_1_key] + assert id(table_2) in client._instance_owners[instance_2_key] + assert len(client._active_instances) == 1 + assert instance_1_key in client._active_instances + assert len(client._instance_owners[instance_2_key]) == 0 + assert len(client._instance_owners[instance_1_key]) == 1 + assert id(table_1) in client._instance_owners[instance_1_key] + assert len(client._active_instances) == 0 + assert len(client._instance_owners[instance_1_key]) == 0 + assert len(client._instance_owners[instance_2_key]) == 0 + + def test_get_table(self): + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + client = 
self._make_client(project="project-id") + assert not client._active_instances + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + table = client.get_table( + expected_instance_id, expected_table_id, expected_app_profile_id + ) + CrossSync._Sync_Impl.yield_to_event_loop() + assert isinstance(table, CrossSync._Sync_Impl.TestTable._get_target_class()) + assert table.table_id == expected_table_id + assert ( + table.table_name + == f"projects/{client.project}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert table.instance_id == expected_instance_id + assert ( + table.instance_name + == f"projects/{client.project}/instances/{expected_instance_id}" + ) + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + client.close() + + def test_get_table_arg_passthrough(self): + """All arguments passed in get_table should be sent to constructor""" + with self._make_client(project="project-id") as client: + with mock.patch.object( + CrossSync._Sync_Impl.TestTable._get_target_class(), "__init__" + ) as mock_constructor: + mock_constructor.return_value = None + assert not client._active_instances + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_args = (1, "test", {"test": 2}) + expected_kwargs = {"hello": "world", "test": 2} + client.get_table( + expected_instance_id, + expected_table_id, + expected_app_profile_id, + *expected_args, + **expected_kwargs, + ) + mock_constructor.assert_called_once_with( + client, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + *expected_args, + **expected_kwargs, + ) + + def test_get_table_context_manager(self): + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_project_id = "project-id" + with mock.patch.object( + CrossSync._Sync_Impl.TestTable._get_target_class(), "close" + ) as close_mock: + with self._make_client(project=expected_project_id) as client: + with client.get_table( + expected_instance_id, expected_table_id, expected_app_profile_id + ) as table: + CrossSync._Sync_Impl.yield_to_event_loop() + assert isinstance( + table, CrossSync._Sync_Impl.TestTable._get_target_class() + ) + assert table.table_id == expected_table_id + assert ( + table.table_name + == f"projects/{expected_project_id}/instances/{expected_instance_id}/tables/{expected_table_id}" + ) + assert table.instance_id == expected_instance_id + assert ( + table.instance_name + == f"projects/{expected_project_id}/instances/{expected_instance_id}" + ) + assert table.app_profile_id == expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + assert close_mock.call_count == 1 + + def test_multiple_pool_sizes(self): + pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] + for pool_size in pool_sizes: + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert 
len(client._channel_refresh_tasks) == pool_size + client_duplicate = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client_duplicate._channel_refresh_tasks) == pool_size + assert str(pool_size) in str(client.transport) + client.close() + client_duplicate.close() + + def test_close(self): + pool_size = 7 + client = self._make_client( + project="project-id", pool_size=pool_size, use_emulator=False + ) + assert len(client._channel_refresh_tasks) == pool_size + tasks_list = list(client._channel_refresh_tasks) + for task in client._channel_refresh_tasks: + assert not task.done() + with mock.patch.object( + CrossSync._Sync_Impl.PooledTransport, "close", CrossSync._Sync_Impl.Mock() + ) as close_mock: + client.close() + close_mock.assert_called_once() + if CrossSync._Sync_Impl.is_async: + close_mock.assert_awaited() + for task in tasks_list: + assert task.done() + + def test_close_with_timeout(self): + pool_size = 7 + expected_timeout = 19 + client = self._make_client(project="project-id", pool_size=pool_size) + tasks = list(client._channel_refresh_tasks) + with mock.patch.object( + CrossSync._Sync_Impl, "wait", CrossSync._Sync_Impl.Mock() + ) as wait_for_mock: + client.close(timeout=expected_timeout) + wait_for_mock.assert_called_once() + if CrossSync._Sync_Impl.is_async: + wait_for_mock.assert_awaited() + assert wait_for_mock.call_args[1]["timeout"] == expected_timeout + client._channel_refresh_tasks = tasks + client.close() + + def test_context_manager(self): + close_mock = CrossSync._Sync_Impl.Mock() + true_close = None + with self._make_client(project="project-id") as client: + true_close = client.close() + client.close = close_mock + for task in client._channel_refresh_tasks: + assert not task.done() + assert client.project == "project-id" + assert client._active_instances == set() + close_mock.assert_not_called() + close_mock.assert_called_once() + if CrossSync._Sync_Impl.is_async: + close_mock.assert_awaited() + true_close + + +@CrossSync._Sync_Impl.add_mapping_decorator("TestTable") +class TestTable: + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + @staticmethod + def _get_target_class(): + return CrossSync._Sync_Impl.Table + + def test_table_ctor(self): + from google.cloud.bigtable.data._helpers import _WarmedInstanceKey + + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_operation_timeout = 123 + expected_attempt_timeout = 12 + expected_read_rows_operation_timeout = 1.5 + expected_read_rows_attempt_timeout = 0.5 + expected_mutate_rows_operation_timeout = 2.5 + expected_mutate_rows_attempt_timeout = 0.75 + client = self._make_client() + assert not client._active_instances + table = self._get_target_class()( + client, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + default_operation_timeout=expected_operation_timeout, + default_attempt_timeout=expected_attempt_timeout, + default_read_rows_operation_timeout=expected_read_rows_operation_timeout, + default_read_rows_attempt_timeout=expected_read_rows_attempt_timeout, + default_mutate_rows_operation_timeout=expected_mutate_rows_operation_timeout, + default_mutate_rows_attempt_timeout=expected_mutate_rows_attempt_timeout, + ) + CrossSync._Sync_Impl.yield_to_event_loop() + assert table.table_id == expected_table_id + assert table.instance_id == expected_instance_id + assert table.app_profile_id == 
expected_app_profile_id + assert table.client is client + instance_key = _WarmedInstanceKey( + table.instance_name, table.table_name, table.app_profile_id + ) + assert instance_key in client._active_instances + assert client._instance_owners[instance_key] == {id(table)} + assert table.default_operation_timeout == expected_operation_timeout + assert table.default_attempt_timeout == expected_attempt_timeout + assert ( + table.default_read_rows_operation_timeout + == expected_read_rows_operation_timeout + ) + assert ( + table.default_read_rows_attempt_timeout + == expected_read_rows_attempt_timeout + ) + assert ( + table.default_mutate_rows_operation_timeout + == expected_mutate_rows_operation_timeout + ) + assert ( + table.default_mutate_rows_attempt_timeout + == expected_mutate_rows_attempt_timeout + ) + table._register_instance_future + assert table._register_instance_future.done() + assert not table._register_instance_future.cancelled() + assert table._register_instance_future.exception() is None + client.close() + + def test_table_ctor_defaults(self): + """should provide default timeout values and app_profile_id""" + expected_table_id = "table-id" + expected_instance_id = "instance-id" + client = self._make_client() + assert not client._active_instances + table = self._get_target_class()( + client, expected_instance_id, expected_table_id + ) + CrossSync._Sync_Impl.yield_to_event_loop() + assert table.table_id == expected_table_id + assert table.instance_id == expected_instance_id + assert table.app_profile_id is None + assert table.client is client + assert table.default_operation_timeout == 60 + assert table.default_read_rows_operation_timeout == 600 + assert table.default_mutate_rows_operation_timeout == 600 + assert table.default_attempt_timeout == 20 + assert table.default_read_rows_attempt_timeout == 20 + assert table.default_mutate_rows_attempt_timeout == 60 + client.close() + + def test_table_ctor_invalid_timeout_values(self): + """bad timeout values should raise ValueError""" + client = self._make_client() + timeout_pairs = [ + ("default_operation_timeout", "default_attempt_timeout"), + ( + "default_read_rows_operation_timeout", + "default_read_rows_attempt_timeout", + ), + ( + "default_mutate_rows_operation_timeout", + "default_mutate_rows_attempt_timeout", + ), + ] + for operation_timeout, attempt_timeout in timeout_pairs: + with pytest.raises(ValueError) as e: + self._get_target_class()(client, "", "", **{attempt_timeout: -1}) + assert "attempt_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._get_target_class()(client, "", "", **{operation_timeout: -1}) + assert "operation_timeout must be greater than 0" in str(e.value) + client.close() + + @pytest.mark.parametrize( + "fn_name,fn_args,is_stream,extra_retryables", + [ + ("read_rows_stream", (ReadRowsQuery(),), True, ()), + ("read_rows", (ReadRowsQuery(),), True, ()), + ("read_row", (b"row_key",), True, ()), + ("read_rows_sharded", ([ReadRowsQuery()],), True, ()), + ("row_exists", (b"row_key",), True, ()), + ("sample_row_keys", (), False, ()), + ("mutate_row", (b"row_key", [mock.Mock()]), False, ()), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + False, + (_MutateRowsIncomplete,), + ), + ], + ) + @pytest.mark.parametrize( + "input_retryables,expected_retryables", + [ + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + 
TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), + ], + ) + def test_customizable_retryable_errors( + self, + input_retryables, + expected_retryables, + fn_name, + fn_args, + is_stream, + extra_retryables, + ): + """Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer.""" + retry_fn = "retry_target" + if is_stream: + retry_fn += "_stream" + if CrossSync._Sync_Impl.is_async: + retry_fn = f"CrossSync.{retry_fn}" + else: + retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" + with mock.patch( + f"google.cloud.bigtable.data._sync.cross_sync.{retry_fn}" + ) as retry_fn_mock: + with self._make_client() as client: + table = client.get_table("instance-id", "table-id") + expected_predicate = expected_retryables.__contains__ + retry_fn_mock.side_effect = RuntimeError("stop early") + with mock.patch( + "google.api_core.retry.if_exception_type" + ) as predicate_builder_mock: + predicate_builder_mock.return_value = expected_predicate + with pytest.raises(Exception): + test_fn = table.__getattribute__(fn_name) + test_fn(*fn_args, retryable_errors=input_retryables) + predicate_builder_mock.assert_called_once_with( + *expected_retryables, *extra_retryables + ) + retry_call_args = retry_fn_mock.call_args_list[0].args + assert retry_call_args[1] is expected_predicate + + @pytest.mark.parametrize( + "fn_name,fn_args,gapic_fn", + [ + ("read_rows_stream", (ReadRowsQuery(),), "read_rows"), + ("read_rows", (ReadRowsQuery(),), "read_rows"), + ("read_row", (b"row_key",), "read_rows"), + ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), + ("row_exists", (b"row_key",), "read_rows"), + ("sample_row_keys", (), "sample_row_keys"), + ("mutate_row", (b"row_key", [mock.Mock()]), "mutate_row"), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + "mutate_rows", + ), + ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), + ( + "read_modify_write_row", + (b"row_key", mock.Mock()), + "read_modify_write_row", + ), + ], + ) + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): + """check that all requests attach proper metadata headers""" + profile = "profile" if include_app_profile else None + with mock.patch.object( + CrossSync._Sync_Impl.GapicClient, gapic_fn, CrossSync._Sync_Impl.Mock() + ) as gapic_mock: + gapic_mock.side_effect = RuntimeError("stop early") + with self._make_client() as client: + table = self._get_target_class()( + client, "instance-id", "table-id", profile + ) + try: + test_fn = table.__getattribute__(fn_name) + maybe_stream = test_fn(*fn_args) + [i for i in maybe_stream] + except Exception: + pass + kwargs = gapic_mock.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == "x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata + + +@CrossSync._Sync_Impl.add_mapping_decorator("TestReadRows") +class TestReadRows: + """ + Tests 
for table.read_rows and related methods. + """ + + @staticmethod + def _get_operation_class(): + return CrossSync._Sync_Impl._ReadRowsOperation + + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + def _make_table(self, *args, **kwargs): + client_mock = mock.Mock() + client_mock._register_instance.side_effect = ( + lambda *args, **kwargs: CrossSync._Sync_Impl.yield_to_event_loop() + ) + client_mock._remove_instance_registration.side_effect = ( + lambda *args, **kwargs: CrossSync._Sync_Impl.yield_to_event_loop() + ) + kwargs["instance_id"] = kwargs.get( + "instance_id", args[0] if args else "instance" + ) + kwargs["table_id"] = kwargs.get( + "table_id", args[1] if len(args) > 1 else "table" + ) + client_mock._gapic_client.table_path.return_value = kwargs["table_id"] + client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] + return CrossSync._Sync_Impl.TestTable._get_target_class()( + client_mock, *args, **kwargs + ) + + def _make_stats(self): + from google.cloud.bigtable_v2.types import RequestStats + from google.cloud.bigtable_v2.types import FullReadStatsView + from google.cloud.bigtable_v2.types import ReadIterationStats + + return RequestStats( + full_read_stats_view=FullReadStatsView( + read_iteration_stats=ReadIterationStats( + rows_seen_count=1, + rows_returned_count=2, + cells_seen_count=3, + cells_returned_count=4, + ) + ) + ) + + @staticmethod + def _make_chunk(*args, **kwargs): + from google.cloud.bigtable_v2 import ReadRowsResponse + + kwargs["row_key"] = kwargs.get("row_key", b"row_key") + kwargs["family_name"] = kwargs.get("family_name", "family_name") + kwargs["qualifier"] = kwargs.get("qualifier", b"qualifier") + kwargs["value"] = kwargs.get("value", b"value") + kwargs["commit_row"] = kwargs.get("commit_row", True) + return ReadRowsResponse.CellChunk(*args, **kwargs) + + @staticmethod + def _make_gapic_stream( + chunk_list: list[ReadRowsResponse.CellChunk | Exception], sleep_time=0 + ): + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list, sleep_time): + self.chunk_list = chunk_list + self.idx = -1 + self.sleep_time = sleep_time + + def __iter__(self): + return self + + def __next__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + if sleep_time: + CrossSync._Sync_Impl.sleep(self.sleep_time) + chunk = self.chunk_list[self.idx] + if isinstance(chunk, Exception): + raise chunk + else: + return ReadRowsResponse(chunks=[chunk]) + raise CrossSync._Sync_Impl.StopIteration + + def cancel(self): + pass + + return mock_stream(chunk_list, sleep_time) + + def execute_fn(self, table, *args, **kwargs): + return table.read_rows(*args, **kwargs) + + def test_read_rows(self): + query = ReadRowsQuery() + chunks = [ + self._make_chunk(row_key=b"test_1"), + self._make_chunk(row_key=b"test_2"), + ] + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + results = self.execute_fn(table, query, operation_timeout=3) + assert len(results) == 2 + assert results[0].row_key == b"test_1" + assert results[1].row_key == b"test_2" + + def test_read_rows_stream(self): + query = ReadRowsQuery() + chunks = [ + self._make_chunk(row_key=b"test_1"), + self._make_chunk(row_key=b"test_2"), + ] + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, 
**kwargs: self._make_gapic_stream( + chunks + ) + gen = table.read_rows_stream(query, operation_timeout=3) + results = [row for row in gen] + assert len(results) == 2 + assert results[0].row_key == b"test_1" + assert results[1].row_key == b"test_2" + + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_read_rows_query_matches_request(self, include_app_profile): + from google.cloud.bigtable.data import RowRange + from google.cloud.bigtable.data.row_filters import PassAllFilter + + app_profile_id = "app_profile_id" if include_app_profile else None + with self._make_table(app_profile_id=app_profile_id) as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream([]) + row_keys = [b"test_1", "test_2"] + row_ranges = RowRange("1start", "2end") + filter_ = PassAllFilter(True) + limit = 99 + query = ReadRowsQuery( + row_keys=row_keys, + row_ranges=row_ranges, + row_filter=filter_, + limit=limit, + ) + results = table.read_rows(query, operation_timeout=3) + assert len(results) == 0 + call_request = read_rows.call_args_list[0][0][0] + query_pb = query._to_pb(table) + assert call_request == query_pb + + @pytest.mark.parametrize("operation_timeout", [0.001, 0.023, 0.1]) + def test_read_rows_timeout(self, operation_timeout): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + query = ReadRowsQuery() + chunks = [self._make_chunk(row_key=b"test_1")] + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=0.15 + ) + try: + table.read_rows(query, operation_timeout=operation_timeout) + except core_exceptions.DeadlineExceeded as e: + assert ( + e.message + == f"operation_timeout of {operation_timeout:0.1f}s exceeded" + ) + + @pytest.mark.parametrize( + "per_request_t, operation_t, expected_num", + [(0.05, 0.08, 2), (0.05, 0.14, 3), (0.05, 0.24, 5)], + ) + def test_read_rows_attempt_timeout(self, per_request_t, operation_t, expected_num): + """Ensures that the attempt_timeout is respected and that the number of + requests is as expected. 
+ + operation_timeout does not cancel the request, so we expect the number of + requests to be the ceiling of operation_timeout / attempt_timeout.""" + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + expected_last_timeout = operation_t - (expected_num - 1) * per_request_t + with mock.patch("random.uniform", side_effect=lambda a, b: 0): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks, sleep_time=per_request_t + ) + query = ReadRowsQuery() + chunks = [core_exceptions.DeadlineExceeded("mock deadline")] + try: + table.read_rows( + query, + operation_timeout=operation_t, + attempt_timeout=per_request_t, + ) + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + if expected_num == 0: + assert retry_exc is None + else: + assert type(retry_exc) is RetryExceptionGroup + assert f"{expected_num} failed attempts" in str(retry_exc) + assert len(retry_exc.exceptions) == expected_num + for sub_exc in retry_exc.exceptions: + assert sub_exc.message == "mock deadline" + assert read_rows.call_count == expected_num + for _, call_kwargs in read_rows.call_args_list[:-1]: + assert call_kwargs["timeout"] == per_request_t + assert call_kwargs["retry"] is None + assert ( + abs( + read_rows.call_args_list[-1][1]["timeout"] + - expected_last_timeout + ) + < 0.05 + ) + + @pytest.mark.parametrize( + "exc_type", + [ + core_exceptions.Aborted, + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + ], + ) + def test_read_rows_retryable_error(self, exc_type): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + try: + table.read_rows(query, operation_timeout=0.1) + except core_exceptions.DeadlineExceeded as e: + retry_exc = e.__cause__ + root_cause = retry_exc.exceptions[0] + assert type(root_cause) is exc_type + assert root_cause == expected_error + + @pytest.mark.parametrize( + "exc_type", + [ + core_exceptions.Cancelled, + core_exceptions.PreconditionFailed, + core_exceptions.NotFound, + core_exceptions.PermissionDenied, + core_exceptions.Conflict, + core_exceptions.InternalServerError, + core_exceptions.TooManyRequests, + core_exceptions.ResourceExhausted, + InvalidChunk, + ], + ) + def test_read_rows_non_retryable_error(self, exc_type): + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + try: + table.read_rows(query, operation_timeout=0.1) + except exc_type as e: + assert e == expected_error + + def test_read_rows_revise_request(self): + """Ensure that _revise_request is called between retries""" + from google.cloud.bigtable.data.exceptions import InvalidChunk + from google.cloud.bigtable_v2.types import RowSet + + return_val = RowSet() + with mock.patch.object( + self._get_operation_class(), "_revise_request_rowset" + ) as revise_rowset: + revise_rowset.return_value = return_val + with self._make_table() as table: + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + chunks + ) + row_keys = [b"test_1", b"test_2", b"test_3"] + query = 
ReadRowsQuery(row_keys=row_keys) + chunks = [ + self._make_chunk(row_key=b"test_1"), + core_exceptions.Aborted("mock retryable error"), + ] + try: + table.read_rows(query) + except InvalidChunk: + revise_rowset.assert_called() + first_call_kwargs = revise_rowset.call_args_list[0].kwargs + assert first_call_kwargs["row_set"] == query._to_pb(table).rows + assert first_call_kwargs["last_seen_row_key"] == b"test_1" + revised_call = read_rows.call_args_list[1].args[0] + assert revised_call.rows == return_val + + def test_read_rows_default_timeouts(self): + """Ensure that the default timeouts are set on the read rows operation when not overridden""" + operation_timeout = 8 + attempt_timeout = 4 + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: + mock_op.side_effect = RuntimeError("mock error") + with self._make_table( + default_read_rows_operation_timeout=operation_timeout, + default_read_rows_attempt_timeout=attempt_timeout, + ) as table: + try: + table.read_rows(ReadRowsQuery()) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["attempt_timeout"] == attempt_timeout + + def test_read_rows_default_timeout_override(self): + """When timeouts are passed, they overwrite default values""" + operation_timeout = 8 + attempt_timeout = 4 + with mock.patch.object(self._get_operation_class(), "__init__") as mock_op: + mock_op.side_effect = RuntimeError("mock error") + with self._make_table( + default_operation_timeout=99, default_attempt_timeout=97 + ) as table: + try: + table.read_rows( + ReadRowsQuery(), + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + ) + except RuntimeError: + pass + kwargs = mock_op.call_args_list[0].kwargs + assert kwargs["operation_timeout"] == operation_timeout + assert kwargs["attempt_timeout"] == attempt_timeout + + def test_read_row(self): + """Test reading a single row""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + expected_result = object() + read_rows.side_effect = lambda *args, **kwargs: [expected_result] + expected_op_timeout = 8 + expected_req_timeout = 4 + row = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert row == expected_result + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert len(args) == 1 + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + + def test_read_row_w_filter(self): + """Test reading a single row with an added filter""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + expected_result = object() + read_rows.side_effect = lambda *args, **kwargs: [expected_result] + expected_op_timeout = 8 + expected_req_timeout = 4 + mock_filter = mock.Mock() + expected_filter = {"filter": "mock filter"} + mock_filter._to_dict.return_value = expected_filter + row = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + row_filter=expected_filter, + ) + assert row == expected_result + assert 
read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert len(args) == 1 + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + assert query.filter == expected_filter + + def test_read_row_no_response(self): + """should return None if row does not exist""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = lambda *args, **kwargs: [] + expected_op_timeout = 8 + expected_req_timeout = 4 + result = table.read_row( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert result is None + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert isinstance(args[0], ReadRowsQuery) + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + + @pytest.mark.parametrize( + "return_value,expected_result", + [([], False), ([object()], True), ([object(), object()], True)], + ) + def test_row_exists(self, return_value, expected_result): + """Test checking for row existence""" + with self._make_client() as client: + table = client.get_table("instance", "table") + row_key = b"test_1" + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = lambda *args, **kwargs: return_value + expected_op_timeout = 1 + expected_req_timeout = 2 + result = table.row_exists( + row_key, + operation_timeout=expected_op_timeout, + attempt_timeout=expected_req_timeout, + ) + assert expected_result == result + assert read_rows.call_count == 1 + args, kwargs = read_rows.call_args_list[0] + assert kwargs["operation_timeout"] == expected_op_timeout + assert kwargs["attempt_timeout"] == expected_req_timeout + assert isinstance(args[0], ReadRowsQuery) + expected_filter = { + "chain": { + "filters": [ + {"cells_per_row_limit_filter": 1}, + {"strip_value_transformer": True}, + ] + } + } + query = args[0] + assert query.row_keys == [row_key] + assert query.row_ranges == [] + assert query.limit == 1 + assert query.filter._to_dict() == expected_filter + + +class TestReadRowsSharded: + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + def test_read_rows_sharded_empty_query(self): + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as exc: + table.read_rows_sharded([]) + assert "empty sharded_query" in str(exc.value) + + def test_read_rows_sharded_multiple_queries(self): + """Test with multiple queries. 
Should return results from both""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, "read_rows" + ) as read_rows: + read_rows.side_effect = lambda *args, **kwargs: CrossSync._Sync_Impl.TestReadRows._make_gapic_stream( + [ + CrossSync._Sync_Impl.TestReadRows._make_chunk(row_key=k) + for k in args[0].rows.row_keys + ] + ) + query_1 = ReadRowsQuery(b"test_1") + query_2 = ReadRowsQuery(b"test_2") + result = table.read_rows_sharded([query_1, query_2]) + assert len(result) == 2 + assert result[0].row_key == b"test_1" + assert result[1].row_key == b"test_2" + + @pytest.mark.parametrize("n_queries", [1, 2, 5, 11, 24]) + def test_read_rows_sharded_multiple_queries_calls(self, n_queries): + """Each query should trigger a separate read_rows call""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + query_list = [ReadRowsQuery() for _ in range(n_queries)] + table.read_rows_sharded(query_list) + assert read_rows.call_count == n_queries + + def test_read_rows_sharded_errors(self): + """Errors should be exposed as ShardedReadRowsExceptionGroups""" + from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup + from google.cloud.bigtable.data.exceptions import FailedQueryShardError + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = RuntimeError("mock error") + query_1 = ReadRowsQuery(b"test_1") + query_2 = ReadRowsQuery(b"test_2") + with pytest.raises(ShardedReadRowsExceptionGroup) as exc: + table.read_rows_sharded([query_1, query_2]) + exc_group = exc.value + assert isinstance(exc_group, ShardedReadRowsExceptionGroup) + assert len(exc.value.exceptions) == 2 + assert isinstance(exc.value.exceptions[0], FailedQueryShardError) + assert isinstance(exc.value.exceptions[0].__cause__, RuntimeError) + assert exc.value.exceptions[0].index == 0 + assert exc.value.exceptions[0].query == query_1 + assert isinstance(exc.value.exceptions[1], FailedQueryShardError) + assert isinstance(exc.value.exceptions[1].__cause__, RuntimeError) + assert exc.value.exceptions[1].index == 1 + assert exc.value.exceptions[1].query == query_2 + + def test_read_rows_sharded_concurrent(self): + """Ensure sharded requests are concurrent""" + import time + + def mock_call(*args, **kwargs): + CrossSync._Sync_Impl.sleep(0.1) + return [mock.Mock()] + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object(table, "read_rows") as read_rows: + read_rows.side_effect = mock_call + queries = [ReadRowsQuery() for _ in range(10)] + start_time = time.monotonic() + result = table.read_rows_sharded(queries) + call_time = time.monotonic() - start_time + assert read_rows.call_count == 10 + assert len(result) == 10 + assert call_time < 0.2 + + def test_read_rows_sharded_concurrency_limit(self): + """Only 10 queries should be processed concurrently. 
Others should be queued
+
+        Should start a new query as soon as the previous one finishes"""
+        from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT
+
+        assert _CONCURRENCY_LIMIT == 10
+        num_queries = 15
+        increment_time = 0.05
+        max_time = increment_time * (_CONCURRENCY_LIMIT - 1)
+        rpc_times = [min(i * increment_time, max_time) for i in range(num_queries)]
+
+        def mock_call(*args, **kwargs):
+            next_sleep = rpc_times.pop(0)
+            # block for the mocked RPC duration using the sync sleep helper
+            CrossSync._Sync_Impl.sleep(next_sleep)
+            return [mock.Mock()]
+
+        starting_timeout = 10
+        with self._make_client() as client:
+            with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(num_queries)]
+                    table.read_rows_sharded(queries, operation_timeout=starting_timeout)
+                    assert read_rows.call_count == num_queries
+                    rpc_start_list = [
+                        starting_timeout - kwargs["operation_timeout"]
+                        for _, kwargs in read_rows.call_args_list
+                    ]
+                    eps = 0.01
+                    assert all(
+                        (rpc_start_list[i] < eps for i in range(_CONCURRENCY_LIMIT))
+                    )
+                    for i in range(num_queries - _CONCURRENCY_LIMIT):
+                        idx = i + _CONCURRENCY_LIMIT
+                        assert rpc_start_list[idx] - i * increment_time < eps
+
+    def test_read_rows_sharded_expirary(self):
+        """If the operation times out before all shards complete, should raise
+        a ShardedReadRowsExceptionGroup"""
+        from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT
+        from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+        from google.api_core.exceptions import DeadlineExceeded
+
+        operation_timeout = 0.1
+        num_queries = 15
+        sleeps = [0] * _CONCURRENCY_LIMIT + [DeadlineExceeded("times up")] * (
+            num_queries - _CONCURRENCY_LIMIT
+        )
+
+        def mock_call(*args, **kwargs):
+            next_item = sleeps.pop(0)
+            if isinstance(next_item, Exception):
+                raise next_item
+            else:
+                # simulate RPC latency with the sync sleep helper
+                CrossSync._Sync_Impl.sleep(next_item)
+            return [mock.Mock()]
+
+        with self._make_client() as client:
+            with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(num_queries)]
+                    with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
+                        table.read_rows_sharded(
+                            queries, operation_timeout=operation_timeout
+                        )
+                    assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
+                    assert len(exc.value.exceptions) == num_queries - _CONCURRENCY_LIMIT
+                    assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT
+
+    def test_read_rows_sharded_negative_batch_timeout(self):
+        """Try to run with batches that start after the operation timeout
+
+        They should raise DeadlineExceeded errors"""
+        from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+        from google.api_core.exceptions import DeadlineExceeded
+
+        def mock_call(*args, **kwargs):
+            CrossSync._Sync_Impl.sleep(0.05)
+            return [mock.Mock()]
+
+        with self._make_client() as client:
+            with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(15)]
+                    with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
+                        table.read_rows_sharded(queries, operation_timeout=0.01)
+                    assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
+                    assert len(exc.value.exceptions) == 5
+                    assert all(
+                        (
+                            isinstance(e.__cause__, DeadlineExceeded)
+                            for e in exc.value.exceptions
+                        )
+                    )
+
+
+class TestSampleRowKeys:
+    def _make_client(self, *args, **kwargs):
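+        # Shared helper: delegates to the client factory defined on the generated TestBigtableDataClient suite above.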
return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): + from google.cloud.bigtable_v2.types import SampleRowKeysResponse + + for value in sample_list: + yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) + + def test_sample_row_keys(self): + """Test that method returns the expected key samples""" + samples = [(b"test_1", 0), (b"test_2", 100), (b"test_3", 200)] + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream(samples) + result = table.sample_row_keys() + assert len(result) == 3 + assert all((isinstance(r, tuple) for r in result)) + assert all((isinstance(r[0], bytes) for r in result)) + assert all((isinstance(r[1], int) for r in result)) + assert result[0] == samples[0] + assert result[1] == samples[1] + assert result[2] == samples[2] + + def test_sample_row_keys_bad_timeout(self): + """should raise error if timeout is negative""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.sample_row_keys(operation_timeout=-1) + assert "operation_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + table.sample_row_keys(attempt_timeout=-1) + assert "attempt_timeout must be greater than 0" in str(e.value) + + def test_sample_row_keys_default_timeout(self): + """Should fallback to using table default operation_timeout""" + expected_timeout = 99 + with self._make_client() as client: + with client.get_table( + "i", + "t", + default_operation_timeout=expected_timeout, + default_attempt_timeout=expected_timeout, + ) as table: + with mock.patch.object( + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream([]) + result = table.sample_row_keys() + _, kwargs = sample_row_keys.call_args + assert abs(kwargs["timeout"] - expected_timeout) < 0.1 + assert result == [] + assert kwargs["retry"] is None + + def test_sample_row_keys_gapic_params(self): + """make sure arguments are propagated to gapic call as expected""" + expected_timeout = 10 + expected_profile = "test1" + instance = "instance_name" + table_id = "my_table" + with self._make_client() as client: + with client.get_table( + instance, table_id, app_profile_id=expected_profile + ) as table: + with mock.patch.object( + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), + ) as sample_row_keys: + sample_row_keys.return_value = self._make_gapic_stream([]) + table.sample_row_keys(attempt_timeout=expected_timeout) + args, kwargs = sample_row_keys.call_args + assert len(args) == 0 + assert len(kwargs) == 5 + assert kwargs["timeout"] == expected_timeout + assert kwargs["app_profile_id"] == expected_profile + assert kwargs["table_name"] == table.table_name + assert kwargs["metadata"] is not None + assert kwargs["retry"] is None + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_sample_row_keys_retryable_errors(self, retryable_exception): + """retryable errors should be retried until timeout""" + from google.api_core.exceptions import DeadlineExceeded + from 
google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), + ) as sample_row_keys: + sample_row_keys.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + table.sample_row_keys(operation_timeout=0.05) + cause = e.value.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert len(cause.exceptions) > 0 + assert isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.parametrize( + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + core_exceptions.Aborted, + ], + ) + def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): + """non-retryable errors should cause a raise""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + table.client._gapic_client, + "sample_row_keys", + CrossSync._Sync_Impl.Mock(), + ) as sample_row_keys: + sample_row_keys.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + table.sample_row_keys() + + +class TestMutateRow: + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + @pytest.mark.parametrize( + "mutation_arg", + [ + mutations.SetCell("family", b"qualifier", b"value"), + mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=1234567890 + ), + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromFamily("family"), + mutations.DeleteAllFromRow(), + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromRow(), + ], + ], + ) + def test_mutate_row(self, mutation_arg): + """Test mutations with no errors""" + expected_attempt_timeout = 19 + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.return_value = None + table.mutate_row( + "row_key", + mutation_arg, + attempt_timeout=expected_attempt_timeout, + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0].kwargs + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["row_key"] == b"row_key" + formatted_mutations = ( + [mutation._to_pb() for mutation in mutation_arg] + if isinstance(mutation_arg, list) + else [mutation_arg._to_pb()] + ) + assert kwargs["mutations"] == formatted_mutations + assert kwargs["timeout"] == expected_attempt_timeout + assert kwargs["retry"] is None + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_mutate_row_retryable_errors(self, retryable_exception): + from google.api_core.exceptions import DeadlineExceeded + from google.cloud.bigtable.data.exceptions import RetryExceptionGroup + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + 
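+                        # DeleteAllFromRow is idempotent, so the retryable error keeps being retried
+                        # until the 0.01s operation_timeout expires and surfaces as DeadlineExceeded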
mutation = mutations.DeleteAllFromRow() + assert mutation.is_idempotent() is True + table.mutate_row("row_key", mutation, operation_timeout=0.01) + cause = e.value.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_mutate_row_non_idempotent_retryable_errors(self, retryable_exception): + """Non-idempotent mutations should not be retried""" + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(retryable_exception): + mutation = mutations.SetCell( + "family", b"qualifier", b"value", -1 + ) + assert mutation.is_idempotent() is False + table.mutate_row("row_key", mutation, operation_timeout=0.2) + + @pytest.mark.parametrize( + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + core_exceptions.Aborted, + ], + ) + def test_mutate_row_non_retryable_errors(self, non_retryable_exception): + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_row" + ) as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + mutation = mutations.SetCell( + "family", + b"qualifier", + b"value", + timestamp_micros=1234567890, + ) + assert mutation.is_idempotent() is True + table.mutate_row("row_key", mutation, operation_timeout=0.2) + + @pytest.mark.parametrize("include_app_profile", [True, False]) + def test_mutate_row_metadata(self, include_app_profile): + """request should attach metadata headers""" + profile = "profile" if include_app_profile else None + with self._make_client() as client: + with client.get_table("i", "t", app_profile_id=profile) as table: + with mock.patch.object( + client._gapic_client, "mutate_row", CrossSync._Sync_Impl.Mock() + ) as read_rows: + table.mutate_row("rk", mock.Mock()) + kwargs = read_rows.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == "x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata + + @pytest.mark.parametrize("mutations", [[], None]) + def test_mutate_row_no_mutations(self, mutations): + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.mutate_row("key", mutations=mutations) + assert e.value.args[0] == "No mutations provided" + + +class TestBulkMutateRows: + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + def _mock_response(self, response_list): + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc import status_pb2 + + statuses = [] + for response in response_list: + if isinstance(response, core_exceptions.GoogleAPICallError): + statuses.append( + 
status_pb2.Status( + message=str(response), code=response.grpc_status_code.value[0] + ) + ) + else: + statuses.append(status_pb2.Status(code=0)) + entries = [ + MutateRowsResponse.Entry(index=i, status=statuses[i]) + for i in range(len(response_list)) + ] + + def generator(): + yield MutateRowsResponse(entries=entries) + + return generator() + + @pytest.mark.parametrize( + "mutation_arg", + [ + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=1234567890 + ) + ], + [mutations.DeleteRangeFromColumn("family", b"qualifier")], + [mutations.DeleteAllFromFamily("family")], + [mutations.DeleteAllFromRow()], + [mutations.SetCell("family", b"qualifier", b"value")], + [ + mutations.DeleteRangeFromColumn("family", b"qualifier"), + mutations.DeleteAllFromRow(), + ], + ], + ) + def test_bulk_mutate_rows(self, mutation_arg): + """Test mutations with no errors""" + expected_attempt_timeout = 19 + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.return_value = self._mock_response([None]) + bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg) + table.bulk_mutate_rows( + [bulk_mutation], attempt_timeout=expected_attempt_timeout + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"] == [bulk_mutation._to_pb()] + assert kwargs["timeout"] == expected_attempt_timeout + assert kwargs["retry"] is None + + def test_bulk_mutate_rows_multiple_entries(self): + """Test mutations with no errors""" + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.return_value = self._mock_response([None, None]) + mutation_list = [mutations.DeleteAllFromRow()] + entry_1 = mutations.RowMutationEntry(b"row_key_1", mutation_list) + entry_2 = mutations.RowMutationEntry(b"row_key_2", mutation_list) + table.bulk_mutate_rows([entry_1, entry_2]) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"][0] == entry_1._to_pb() + assert kwargs["entries"][1] == entry_2._to_pb() + + @pytest.mark.parametrize( + "exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_rows_idempotent_mutation_error_retryable(self, exception): + """Individual idempotent mutations should be retried if they fail with a retryable error""" + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + 
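+                # the single failed entry is wrapped as a FailedMutationEntryError whose cause
+                # chains the retry attempts that exhausted the 0.05s operation_timeout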
failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], exception) + assert isinstance( + cause.exceptions[-1], core_exceptions.DeadlineExceeded + ) + + @pytest.mark.parametrize( + "exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + core_exceptions.Aborted, + ], + ) + def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable(self, exception): + """Individual idempotent mutations should not be retried if they fail with a non-retryable error""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, exception) + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_idempotent_retryable_request_errors(self, retryable_exception): + """Individual idempotent mutations should be retried if the request fails with a retryable error""" + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) + + @pytest.mark.parametrize( + "retryable_exception", + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ) + def test_bulk_mutate_rows_non_idempotent_retryable_errors( + self, retryable_exception + ): + """Non-Idempotent mutations should never be retried""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with 
mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [retryable_exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", -1 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is False + table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, retryable_exception) + + @pytest.mark.parametrize( + "non_retryable_exception", + [ + core_exceptions.OutOfRange, + core_exceptions.NotFound, + core_exceptions.FailedPrecondition, + RuntimeError, + ValueError, + ], + ) + def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_exception): + """If the request fails with a non-retryable error, mutations should not be retried""" + from google.cloud.bigtable.data.exceptions import ( + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, non_retryable_exception) + + def test_bulk_mutate_error_index(self): + """Test partial failure, partial success. 
Errors should be associated with the correct index""" + from google.api_core.exceptions import ( + DeadlineExceeded, + ServiceUnavailable, + FailedPrecondition, + ) + from google.cloud.bigtable.data.exceptions import ( + RetryExceptionGroup, + FailedMutationEntryError, + MutationsExceptionGroup, + ) + + with self._make_client(project="project") as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "mutate_rows" + ) as mock_gapic: + mock_gapic.side_effect = [ + self._mock_response([None, ServiceUnavailable("mock"), None]), + self._mock_response([DeadlineExceeded("mock")]), + self._mock_response([FailedPrecondition("final")]), + ] + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entries = [ + mutations.RowMutationEntry( + f"row_key_{i}".encode(), [mutation] + ) + for i in range(3) + ] + assert mutation.is_idempotent() is True + table.bulk_mutate_rows(entries, operation_timeout=1000) + assert len(e.value.exceptions) == 1 + failed = e.value.exceptions[0] + assert isinstance(failed, FailedMutationEntryError) + assert failed.index == 1 + assert failed.entry == entries[1] + cause = failed.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert len(cause.exceptions) == 3 + assert isinstance(cause.exceptions[0], ServiceUnavailable) + assert isinstance(cause.exceptions[1], DeadlineExceeded) + assert isinstance(cause.exceptions[2], FailedPrecondition) + + def test_bulk_mutate_error_recovery(self): + """If an error occurs, then resolves, no exception should be raised""" + from google.api_core.exceptions import DeadlineExceeded + + with self._make_client(project="project") as client: + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.side_effect = [ + self._mock_response([DeadlineExceeded("mock")]), + self._mock_response([None]), + ] + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entries = [ + mutations.RowMutationEntry(f"row_key_{i}".encode(), [mutation]) + for i in range(3) + ] + table.bulk_mutate_rows(entries, operation_timeout=1000) + + +class TestCheckAndMutateRow: + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + @pytest.mark.parametrize("gapic_result", [True, False]) + def test_check_and_mutate(self, gapic_result): + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + app_profile = "app_profile_id" + with self._make_client() as client: + with client.get_table( + "instance", "table", app_profile_id=app_profile + ) as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=gapic_result + ) + row_key = b"row_key" + predicate = None + true_mutations = [mock.Mock()] + false_mutations = [mock.Mock(), mock.Mock()] + operation_timeout = 0.2 + found = table.check_and_mutate_row( + row_key, + predicate, + true_case_mutations=true_mutations, + false_case_mutations=false_mutations, + operation_timeout=operation_timeout, + ) + assert found == gapic_result + kwargs = mock_gapic.call_args[1] + assert kwargs["table_name"] == table.table_name + assert kwargs["row_key"] == row_key + assert kwargs["predicate_filter"] == predicate + assert kwargs["true_mutations"] == [ + m._to_pb() for m in 
true_mutations + ] + assert kwargs["false_mutations"] == [ + m._to_pb() for m in false_mutations + ] + assert kwargs["app_profile_id"] == app_profile + assert kwargs["timeout"] == operation_timeout + assert kwargs["retry"] is None + + def test_check_and_mutate_bad_timeout(self): + """Should raise error if operation_timeout < 0""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=[mock.Mock()], + false_case_mutations=[], + operation_timeout=-1, + ) + assert str(e.value) == "operation_timeout must be greater than 0" + + def test_check_and_mutate_single_mutations(self): + """if single mutations are passed, they should be internally wrapped in a list""" + from google.cloud.bigtable.data.mutations import SetCell + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + true_mutation = SetCell("family", b"qualifier", b"value") + false_mutation = SetCell("family", b"qualifier", b"value") + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=true_mutation, + false_case_mutations=false_mutation, + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["true_mutations"] == [true_mutation._to_pb()] + assert kwargs["false_mutations"] == [false_mutation._to_pb()] + + def test_check_and_mutate_predicate_object(self): + """predicate filter should be passed to gapic request""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + + mock_predicate = mock.Mock() + predicate_pb = {"predicate": "dict"} + mock_predicate._to_pb.return_value = predicate_pb + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + table.check_and_mutate_row( + b"row_key", mock_predicate, false_case_mutations=[mock.Mock()] + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["predicate_filter"] == predicate_pb + assert mock_predicate._to_pb.call_count == 1 + assert kwargs["retry"] is None + + def test_check_and_mutate_mutations_parsing(self): + """mutations objects should be converted to protos""" + from google.cloud.bigtable_v2.types import CheckAndMutateRowResponse + from google.cloud.bigtable.data.mutations import DeleteAllFromRow + + mutations = [mock.Mock() for _ in range(5)] + for idx, mutation in enumerate(mutations): + mutation._to_pb.return_value = f"fake {idx}" + mutations.append(DeleteAllFromRow()) + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "check_and_mutate_row" + ) as mock_gapic: + mock_gapic.return_value = CheckAndMutateRowResponse( + predicate_matched=True + ) + table.check_and_mutate_row( + b"row_key", + None, + true_case_mutations=mutations[0:2], + false_case_mutations=mutations[2:], + ) + kwargs = mock_gapic.call_args[1] + assert kwargs["true_mutations"] == ["fake 0", "fake 1"] + assert kwargs["false_mutations"] == [ + "fake 2", + "fake 3", + "fake 4", + DeleteAllFromRow()._to_pb(), + ] + assert all( + (mutation._to_pb.call_count == 1 
for mutation in mutations[:5]) + ) + + +class TestReadModifyWriteRow: + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + @pytest.mark.parametrize( + "call_rules,expected_rules", + [ + ( + AppendValueRule("f", "c", b"1"), + [AppendValueRule("f", "c", b"1")._to_pb()], + ), + ( + [AppendValueRule("f", "c", b"1")], + [AppendValueRule("f", "c", b"1")._to_pb()], + ), + (IncrementRule("f", "c", 1), [IncrementRule("f", "c", 1)._to_pb()]), + ( + [AppendValueRule("f", "c", b"1"), IncrementRule("f", "c", 1)], + [ + AppendValueRule("f", "c", b"1")._to_pb(), + IncrementRule("f", "c", 1)._to_pb(), + ], + ), + ], + ) + def test_read_modify_write_call_rule_args(self, call_rules, expected_rules): + """Test that the gapic call is called with given rules""" + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row("key", call_rules) + assert mock_gapic.call_count == 1 + found_kwargs = mock_gapic.call_args_list[0][1] + assert found_kwargs["rules"] == expected_rules + assert found_kwargs["retry"] is None + + @pytest.mark.parametrize("rules", [[], None]) + def test_read_modify_write_no_rules(self, rules): + with self._make_client() as client: + with client.get_table("instance", "table") as table: + with pytest.raises(ValueError) as e: + table.read_modify_write_row("key", rules=rules) + assert e.value.args[0] == "rules must contain at least one item" + + def test_read_modify_write_call_defaults(self): + instance = "instance1" + table_id = "table1" + project = "project1" + row_key = "row_key1" + with self._make_client(project=project) as client: + with client.get_table(instance, table_id) as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert ( + kwargs["table_name"] + == f"projects/{project}/instances/{instance}/tables/{table_id}" + ) + assert kwargs["app_profile_id"] is None + assert kwargs["row_key"] == row_key.encode() + assert kwargs["timeout"] > 1 + + def test_read_modify_write_call_overrides(self): + row_key = b"row_key1" + expected_timeout = 12345 + profile_id = "profile1" + with self._make_client() as client: + with client.get_table( + "instance", "table_id", app_profile_id=profile_id + ) as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row( + row_key, mock.Mock(), operation_timeout=expected_timeout + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["app_profile_id"] is profile_id + assert kwargs["row_key"] == row_key + assert kwargs["timeout"] == expected_timeout + + def test_read_modify_write_string_key(self): + row_key = "string_row_key1" + with self._make_client() as client: + with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + table.read_modify_write_row(row_key, mock.Mock()) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0][1] + assert kwargs["row_key"] == row_key.encode() + + def test_read_modify_write_row_building(self): + """results from gapic call should be used to construct row""" + from google.cloud.bigtable.data.row import Row 
+ from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + from google.cloud.bigtable_v2.types import Row as RowPB + + mock_response = ReadModifyWriteRowResponse(row=RowPB()) + with self._make_client() as client: + with client.get_table("instance", "table_id") as table: + with mock.patch.object( + client._gapic_client, "read_modify_write_row" + ) as mock_gapic: + with mock.patch.object(Row, "_from_pb") as constructor_mock: + mock_gapic.return_value = mock_response + table.read_modify_write_row("key", mock.Mock()) + assert constructor_mock.call_count == 1 + constructor_mock.assert_called_once_with(mock_response.row) diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py new file mode 100644 index 000000000..49cc3efeb --- /dev/null +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -0,0 +1,1081 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by CrossSync. Do not edit manually. +import pytest +import asyncio +import time +import google.api_core.exceptions as core_exceptions +import google.api_core.retry +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data import TABLE_DEFAULT +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +try: + from unittest import mock +except ImportError: + import mock + + +class Test_FlowControl: + @staticmethod + def _target_class(): + return CrossSync._Sync_Impl._FlowControl + + def _make_one(self, max_mutation_count=10, max_mutation_bytes=100): + return self._target_class()(max_mutation_count, max_mutation_bytes) + + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def test_ctor(self): + max_mutation_count = 9 + max_mutation_bytes = 19 + instance = self._make_one(max_mutation_count, max_mutation_bytes) + assert instance._max_mutation_count == max_mutation_count + assert instance._max_mutation_bytes == max_mutation_bytes + assert instance._in_flight_mutation_count == 0 + assert instance._in_flight_mutation_bytes == 0 + assert isinstance(instance._capacity_condition, CrossSync._Sync_Impl.Condition) + + def test_ctor_invalid_values(self): + """Test that values are positive, and fit within expected limits""" + with pytest.raises(ValueError) as e: + self._make_one(0, 1) + assert "max_mutation_count must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._make_one(1, 0) + assert "max_mutation_bytes must be greater than 0" in str(e.value) + + @pytest.mark.parametrize( + "max_count,max_size,existing_count,existing_size,new_count,new_size,expected", + [ + (1, 1, 0, 0, 0, 0, True), + (1, 1, 1, 1, 1, 1, False), + (10, 10, 0, 0, 0, 0, True), + (10, 10, 0, 0, 9, 9, True), + (10, 10, 0, 0, 11, 9, True), + (10, 10, 0, 1, 11, 9, True), + (10, 10, 1, 0, 11, 9, False), + (10, 10, 0, 0, 9, 11, True), + (10, 
10, 1, 0, 9, 11, True), + (10, 10, 0, 1, 9, 11, False), + (10, 1, 0, 0, 1, 0, True), + (1, 10, 0, 0, 0, 8, True), + (float("inf"), float("inf"), 0, 0, 10000000000.0, 10000000000.0, True), + (8, 8, 0, 0, 10000000000.0, 10000000000.0, True), + (12, 12, 6, 6, 5, 5, True), + (12, 12, 5, 5, 6, 6, True), + (12, 12, 6, 6, 6, 6, True), + (12, 12, 6, 6, 7, 7, False), + (12, 12, 0, 0, 13, 13, True), + (12, 12, 12, 0, 0, 13, True), + (12, 12, 0, 12, 13, 0, True), + (12, 12, 1, 1, 13, 13, False), + (12, 12, 1, 1, 0, 13, False), + (12, 12, 1, 1, 13, 0, False), + ], + ) + def test__has_capacity( + self, + max_count, + max_size, + existing_count, + existing_size, + new_count, + new_size, + expected, + ): + """_has_capacity should return True if the new mutation will will not exceed the max count or size""" + instance = self._make_one(max_count, max_size) + instance._in_flight_mutation_count = existing_count + instance._in_flight_mutation_bytes = existing_size + assert instance._has_capacity(new_count, new_size) == expected + + @pytest.mark.parametrize( + "existing_count,existing_size,added_count,added_size,new_count,new_size", + [ + (0, 0, 0, 0, 0, 0), + (2, 2, 1, 1, 1, 1), + (2, 0, 1, 0, 1, 0), + (0, 2, 0, 1, 0, 1), + (10, 10, 0, 0, 10, 10), + (10, 10, 5, 5, 5, 5), + (0, 0, 1, 1, -1, -1), + ], + ) + def test_remove_from_flow_value_update( + self, + existing_count, + existing_size, + added_count, + added_size, + new_count, + new_size, + ): + """completed mutations should lower the inflight values""" + instance = self._make_one() + instance._in_flight_mutation_count = existing_count + instance._in_flight_mutation_bytes = existing_size + mutation = self._make_mutation(added_count, added_size) + instance.remove_from_flow(mutation) + assert instance._in_flight_mutation_count == new_count + assert instance._in_flight_mutation_bytes == new_size + + def test__remove_from_flow_unlock(self): + """capacity condition should notify after mutation is complete""" + instance = self._make_one(10, 10) + instance._in_flight_mutation_count = 10 + instance._in_flight_mutation_bytes = 10 + + def task_routine(): + with instance._capacity_condition: + instance._capacity_condition.wait_for( + lambda: instance._has_capacity(1, 1) + ) + + if CrossSync._Sync_Impl.is_async: + task = asyncio.create_task(task_routine()) + + def task_alive(): + return not task.done() + + else: + import threading + + thread = threading.Thread(target=task_routine) + thread.start() + task_alive = thread.is_alive + CrossSync._Sync_Impl.sleep(0.05) + assert task_alive() is True + mutation = self._make_mutation(count=0, size=5) + instance.remove_from_flow([mutation]) + CrossSync._Sync_Impl.sleep(0.05) + assert instance._in_flight_mutation_count == 10 + assert instance._in_flight_mutation_bytes == 5 + assert task_alive() is True + instance._in_flight_mutation_bytes = 10 + mutation = self._make_mutation(count=5, size=0) + instance.remove_from_flow([mutation]) + CrossSync._Sync_Impl.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 10 + assert task_alive() is True + instance._in_flight_mutation_count = 10 + mutation = self._make_mutation(count=5, size=5) + instance.remove_from_flow([mutation]) + CrossSync._Sync_Impl.sleep(0.05) + assert instance._in_flight_mutation_count == 5 + assert instance._in_flight_mutation_bytes == 5 + assert task_alive() is False + + @pytest.mark.parametrize( + "mutations,count_cap,size_cap,expected_results", + [ + ([(5, 5), (1, 1), (1, 1)], 10, 10, [[(5, 5), (1, 1), (1, 1)]]), 
+ ([(1, 1), (1, 1), (1, 1)], 1, 1, [[(1, 1)], [(1, 1)], [(1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 2, 10, [[(1, 1), (1, 1)], [(1, 1)]]), + ([(1, 1), (1, 1), (1, 1)], 10, 2, [[(1, 1), (1, 1)], [(1, 1)]]), + ( + [(1, 1), (5, 5), (4, 1), (1, 4), (1, 1)], + 5, + 5, + [[(1, 1)], [(5, 5)], [(4, 1), (1, 4)], [(1, 1)]], + ), + ], + ) + def test_add_to_flow(self, mutations, count_cap, size_cap, expected_results): + """Test batching with various flow control settings""" + mutation_objs = [self._make_mutation(count=m[0], size=m[1]) for m in mutations] + instance = self._make_one(count_cap, size_cap) + i = 0 + for batch in instance.add_to_flow(mutation_objs): + expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + assert len(batch[j].mutations) == expected_batch[j][0] + assert batch[j].size() == expected_batch[j][1] + instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + @pytest.mark.parametrize( + "mutations,max_limit,expected_results", + [ + ([(1, 1)] * 11, 10, [[(1, 1)] * 10, [(1, 1)]]), + ([(1, 1)] * 10, 1, [[(1, 1)] for _ in range(10)]), + ([(1, 1)] * 10, 2, [[(1, 1), (1, 1)] for _ in range(5)]), + ], + ) + def test_add_to_flow_max_mutation_limits( + self, mutations, max_limit, expected_results + ): + """Test flow control running up against the max API limit + Should submit request early, even if the flow control has room for more""" + subpath = "_async" if CrossSync._Sync_Impl.is_async else "_sync" + path = f"google.cloud.bigtable.data.{subpath}.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT" + with mock.patch(path, max_limit): + mutation_objs = [ + self._make_mutation(count=m[0], size=m[1]) for m in mutations + ] + instance = self._make_one(float("inf"), float("inf")) + i = 0 + for batch in instance.add_to_flow(mutation_objs): + expected_batch = expected_results[i] + assert len(batch) == len(expected_batch) + for j in range(len(expected_batch)): + assert len(batch[j].mutations) == expected_batch[j][0] + assert batch[j].size() == expected_batch[j][1] + instance.remove_from_flow(batch) + i += 1 + assert i == len(expected_results) + + def test_add_to_flow_oversize(self): + """mutations over the flow control limits should still be accepted""" + instance = self._make_one(2, 3) + large_size_mutation = self._make_mutation(count=1, size=10) + large_count_mutation = self._make_mutation(count=10, size=1) + results = [out for out in instance.add_to_flow([large_size_mutation])] + assert len(results) == 1 + instance.remove_from_flow(results[0]) + count_results = [out for out in instance.add_to_flow(large_count_mutation)] + assert len(count_results) == 1 + + +class TestMutationsBatcher: + def _get_target_class(self): + return CrossSync._Sync_Impl.MutationsBatcher + + def _make_one(self, table=None, **kwargs): + from google.api_core.exceptions import DeadlineExceeded + from google.api_core.exceptions import ServiceUnavailable + + if table is None: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 10 + table.default_mutate_rows_retryable_errors = ( + DeadlineExceeded, + ServiceUnavailable, + ) + return self._get_target_class()(table, **kwargs) + + @staticmethod + def _make_mutation(count=1, size=1): + mutation = mock.Mock() + mutation.size.return_value = size + mutation.mutations = [mock.Mock()] * count + return mutation + + def test_ctor_defaults(self): + with mock.patch.object( + self._get_target_class(), + "_timer_routine", + 
return_value=CrossSync._Sync_Impl.Future(), + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = [Exception] + with self._make_one(table) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._max_mutation_count == 100000 + assert instance._flow_control._max_mutation_bytes == 104857600 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert ( + instance._operation_timeout + == table.default_mutate_rows_operation_timeout + ) + assert ( + instance._attempt_timeout + == table.default_mutate_rows_attempt_timeout + ) + assert ( + instance._retryable_errors + == table.default_mutate_rows_retryable_errors + ) + CrossSync._Sync_Impl.yield_to_event_loop() + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == 5 + assert isinstance(instance._flush_timer, CrossSync._Sync_Impl.Future) + + def test_ctor_explicit(self): + """Test with explicit parameters""" + with mock.patch.object( + self._get_target_class(), + "_timer_routine", + return_value=CrossSync._Sync_Impl.Future(), + ) as flush_timer_mock: + table = mock.Mock() + flush_interval = 20 + flush_limit_count = 17 + flush_limit_bytes = 19 + flow_control_max_mutation_count = 1001 + flow_control_max_bytes = 12 + operation_timeout = 11 + attempt_timeout = 2 + retryable_errors = [Exception] + with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + flow_control_max_mutation_count=flow_control_max_mutation_count, + flow_control_max_bytes=flow_control_max_bytes, + batch_operation_timeout=operation_timeout, + batch_attempt_timeout=attempt_timeout, + batch_retryable_errors=retryable_errors, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._flush_jobs == set() + assert len(instance._staged_entries) == 0 + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert ( + instance._flow_control._max_mutation_count + == flow_control_max_mutation_count + ) + assert ( + instance._flow_control._max_mutation_bytes == flow_control_max_bytes + ) + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + assert instance._operation_timeout == operation_timeout + assert instance._attempt_timeout == attempt_timeout + assert instance._retryable_errors == retryable_errors + CrossSync._Sync_Impl.yield_to_event_loop() + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] == flush_interval + assert isinstance(instance._flush_timer, CrossSync._Sync_Impl.Future) + + def test_ctor_no_flush_limits(self): + """Test with None for flush limits""" + with mock.patch.object( + self._get_target_class(), + 
"_timer_routine", + return_value=CrossSync._Sync_Impl.Future(), + ) as flush_timer_mock: + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 10 + table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = () + flush_interval = None + flush_limit_count = None + flush_limit_bytes = None + with self._make_one( + table, + flush_interval=flush_interval, + flush_limit_mutation_count=flush_limit_count, + flush_limit_bytes=flush_limit_bytes, + ) as instance: + assert instance._table == table + assert instance.closed is False + assert instance._staged_entries == [] + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert instance._exception_list_limit == 10 + assert instance._exceptions_since_last_raise == 0 + assert instance._flow_control._in_flight_mutation_count == 0 + assert instance._flow_control._in_flight_mutation_bytes == 0 + assert instance._entries_processed_since_last_raise == 0 + CrossSync._Sync_Impl.yield_to_event_loop() + assert flush_timer_mock.call_count == 1 + assert flush_timer_mock.call_args[0][0] is None + assert isinstance(instance._flush_timer, CrossSync._Sync_Impl.Future) + + def test_ctor_invalid_values(self): + """Test that timeout values are positive, and fit within expected limits""" + with pytest.raises(ValueError) as e: + self._make_one(batch_operation_timeout=-1) + assert "operation_timeout must be greater than 0" in str(e.value) + with pytest.raises(ValueError) as e: + self._make_one(batch_attempt_timeout=-1) + assert "attempt_timeout must be greater than 0" in str(e.value) + + def test_default_argument_consistency(self): + """We supply default arguments in MutationsBatcherAsync.__init__, and in + table.mutations_batcher. Make sure any changes to defaults are applied to + both places""" + import inspect + + get_batcher_signature = dict( + inspect.signature(CrossSync._Sync_Impl.Table.mutations_batcher).parameters + ) + get_batcher_signature.pop("self") + batcher_init_signature = dict( + inspect.signature(self._get_target_class()).parameters + ) + batcher_init_signature.pop("table") + assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) + assert len(get_batcher_signature) == 8 + assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) + for arg_name in get_batcher_signature.keys(): + assert ( + get_batcher_signature[arg_name].default + == batcher_init_signature[arg_name].default + ) + + @pytest.mark.parametrize("input_val", [None, 0, -1]) + def test__start_flush_timer_w_empty_input(self, input_val): + """Empty/invalid timer should return immediately""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + with self._make_one() as instance: + if CrossSync._Sync_Impl.is_async: + sleep_obj, sleep_method = (asyncio, "wait_for") + else: + sleep_obj, sleep_method = (instance._closed, "wait") + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + result = instance._timer_routine(input_val) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + assert result is None + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test__start_flush_timer_call_when_closed(self): + """closed batcher's timer should return immediately""" + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + with self._make_one() as instance: + instance.close() + flush_mock.reset_mock() + if CrossSync._Sync_Impl.is_async: + sleep_obj, sleep_method = 
(asyncio, "wait_for") + else: + sleep_obj, sleep_method = (instance._closed, "wait") + with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: + instance._timer_routine(10) + assert sleep_mock.call_count == 0 + assert flush_mock.call_count == 0 + + @pytest.mark.parametrize("num_staged", [0, 1, 10]) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test__flush_timer(self, num_staged): + """Timer should continue to call _schedule_flush in a loop""" + from google.cloud.bigtable.data._sync.cross_sync import CrossSync + + with mock.patch.object( + self._get_target_class(), "_schedule_flush" + ) as flush_mock: + expected_sleep = 12 + with self._make_one(flush_interval=expected_sleep) as instance: + loop_num = 3 + instance._staged_entries = [mock.Mock()] * num_staged + with mock.patch.object( + CrossSync._Sync_Impl, "event_wait" + ) as sleep_mock: + sleep_mock.side_effect = [None] * loop_num + [TabError("expected")] + with pytest.raises(TabError): + self._get_target_class()._timer_routine( + instance, expected_sleep + ) + if CrossSync._Sync_Impl.is_async: + instance._flush_timer = CrossSync._Sync_Impl.Future() + assert sleep_mock.call_count == loop_num + 1 + sleep_kwargs = sleep_mock.call_args[1] + assert sleep_kwargs["timeout"] == expected_sleep + assert flush_mock.call_count == (0 if num_staged == 0 else loop_num) + + def test__flush_timer_close(self): + """Timer should continue terminate after close""" + with mock.patch.object(self._get_target_class(), "_schedule_flush"): + with self._make_one() as instance: + assert instance._flush_timer.done() is False + instance.close() + assert instance._flush_timer.done() is True + + def test_append_closed(self): + """Should raise exception""" + instance = self._make_one() + instance.close() + with pytest.raises(RuntimeError): + instance.append(mock.Mock()) + + def test_append_wrong_mutation(self): + """Mutation objects should raise an exception. + Only support RowMutationEntry""" + from google.cloud.bigtable.data.mutations import DeleteAllFromRow + + with self._make_one() as instance: + expected_error = "invalid mutation type: DeleteAllFromRow. 
Only RowMutationEntry objects are supported by batcher" + with pytest.raises(ValueError) as e: + instance.append(DeleteAllFromRow()) + assert str(e.value) == expected_error + + def test_append_outside_flow_limits(self): + """entries larger than mutation limits are still processed""" + with self._make_one( + flow_control_max_mutation_count=1, flow_control_max_bytes=1 + ) as instance: + oversized_entry = self._make_mutation(count=0, size=2) + instance.append(oversized_entry) + assert instance._staged_entries == [oversized_entry] + assert instance._staged_count == 0 + assert instance._staged_bytes == 2 + instance._staged_entries = [] + with self._make_one( + flow_control_max_mutation_count=1, flow_control_max_bytes=1 + ) as instance: + overcount_entry = self._make_mutation(count=2, size=0) + instance.append(overcount_entry) + assert instance._staged_entries == [overcount_entry] + assert instance._staged_count == 2 + assert instance._staged_bytes == 0 + instance._staged_entries = [] + + def test_append_flush_runs_after_limit_hit(self): + """If the user appends a bunch of entries above the flush limits back-to-back, + it should still flush in a single task""" + with mock.patch.object( + self._get_target_class(), "_execute_mutate_rows" + ) as op_mock: + with self._make_one(flush_limit_bytes=100) as instance: + + def mock_call(*args, **kwargs): + return [] + + op_mock.side_effect = mock_call + instance.append(self._make_mutation(size=99)) + num_entries = 10 + for _ in range(num_entries): + instance.append(self._make_mutation(size=1)) + instance._wait_for_batch_results(*instance._flush_jobs) + assert op_mock.call_count == 1 + sent_batch = op_mock.call_args[0][0] + assert len(sent_batch) == 2 + assert len(instance._staged_entries) == num_entries - 1 + + @pytest.mark.parametrize( + "flush_count,flush_bytes,mutation_count,mutation_bytes,expect_flush", + [ + (10, 10, 1, 1, False), + (10, 10, 9, 9, False), + (10, 10, 10, 1, True), + (10, 10, 1, 10, True), + (10, 10, 10, 10, True), + (1, 1, 10, 10, True), + (1, 1, 0, 0, False), + ], + ) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_append( + self, flush_count, flush_bytes, mutation_count, mutation_bytes, expect_flush + ): + """test appending different mutations, and checking if it causes a flush""" + with self._make_one( + flush_limit_mutation_count=flush_count, flush_limit_bytes=flush_bytes + ) as instance: + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert instance._staged_entries == [] + mutation = self._make_mutation(count=mutation_count, size=mutation_bytes) + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + instance.append(mutation) + assert flush_mock.call_count == bool(expect_flush) + assert instance._staged_count == mutation_count + assert instance._staged_bytes == mutation_bytes + assert instance._staged_entries == [mutation] + instance._staged_entries = [] + + def test_append_multiple_sequentially(self): + """Append multiple mutations""" + with self._make_one( + flush_limit_mutation_count=8, flush_limit_bytes=8 + ) as instance: + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert instance._staged_entries == [] + mutation = self._make_mutation(count=2, size=3) + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + instance.append(mutation) + assert flush_mock.call_count == 0 + assert instance._staged_count == 2 + assert instance._staged_bytes == 3 + assert len(instance._staged_entries) == 1 + instance.append(mutation) + assert 
flush_mock.call_count == 0 + assert instance._staged_count == 4 + assert instance._staged_bytes == 6 + assert len(instance._staged_entries) == 2 + instance.append(mutation) + assert flush_mock.call_count == 1 + assert instance._staged_count == 6 + assert instance._staged_bytes == 9 + assert len(instance._staged_entries) == 3 + instance._staged_entries = [] + + def test_flush_flow_control_concurrent_requests(self): + """requests should happen in parallel if flow control breaks up single flush into batches""" + import time + + num_calls = 10 + fake_mutations = [self._make_mutation(count=1) for _ in range(num_calls)] + with self._make_one(flow_control_max_mutation_count=1) as instance: + with mock.patch.object( + instance, "_execute_mutate_rows", CrossSync._Sync_Impl.Mock() + ) as op_mock: + + def mock_call(*args, **kwargs): + CrossSync._Sync_Impl.sleep(0.1) + return [] + + op_mock.side_effect = mock_call + start_time = time.monotonic() + instance._staged_entries = fake_mutations + instance._schedule_flush() + CrossSync._Sync_Impl.sleep(0.01) + for i in range(num_calls): + instance._flow_control.remove_from_flow( + [self._make_mutation(count=1)] + ) + CrossSync._Sync_Impl.sleep(0.01) + instance._wait_for_batch_results(*instance._flush_jobs) + duration = time.monotonic() - start_time + assert len(instance._oldest_exceptions) == 0 + assert len(instance._newest_exceptions) == 0 + assert duration < 0.5 + assert op_mock.call_count == num_calls + + def test_schedule_flush_no_mutations(self): + """schedule flush should return None if no staged mutations""" + with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal") as flush_mock: + for i in range(3): + assert instance._schedule_flush() is None + assert flush_mock.call_count == 0 + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_schedule_flush_with_mutations(self): + """if new mutations exist, should add a new flush task to _flush_jobs""" + with self._make_one() as instance: + with mock.patch.object(instance, "_flush_internal") as flush_mock: + if not CrossSync._Sync_Impl.is_async: + flush_mock.side_effect = lambda x: time.sleep(0.1) + for i in range(1, 4): + mutation = mock.Mock() + instance._staged_entries = [mutation] + instance._schedule_flush() + assert instance._staged_entries == [] + asyncio.sleep(0) + assert instance._staged_entries == [] + assert instance._staged_count == 0 + assert instance._staged_bytes == 0 + assert flush_mock.call_count == 1 + flush_mock.reset_mock() + + def test__flush_internal(self): + """_flush_internal should: + - await previous flush call + - delegate batching to _flow_control + - call _execute_mutate_rows on each batch + - update self.exceptions and self._entries_processed_since_last_raise""" + num_entries = 10 + with self._make_one() as instance: + with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: + with mock.patch.object( + instance._flow_control, "add_to_flow" + ) as flow_mock: + + def gen(x): + yield x + + flow_mock.side_effect = lambda x: gen(x) + mutations = [self._make_mutation(count=1, size=1)] * num_entries + instance._flush_internal(mutations) + assert instance._entries_processed_since_last_raise == num_entries + assert execute_mock.call_count == 1 + assert flow_mock.call_count == 1 + instance._oldest_exceptions.clear() + instance._newest_exceptions.clear() + + def test_flush_clears_job_list(self): + """a job should be added to _flush_jobs when _schedule_flush is called, + and removed when it completes""" + with self._make_one() as 
instance: + with mock.patch.object( + instance, "_flush_internal", CrossSync._Sync_Impl.Mock() + ) as flush_mock: + if not CrossSync._Sync_Impl.is_async: + flush_mock.side_effect = lambda x: time.sleep(0.1) + mutations = [self._make_mutation(count=1, size=1)] + instance._staged_entries = mutations + assert instance._flush_jobs == set() + new_job = instance._schedule_flush() + assert instance._flush_jobs == {new_job} + if CrossSync._Sync_Impl.is_async: + new_job + else: + new_job.result() + assert instance._flush_jobs == set() + + @pytest.mark.parametrize( + "num_starting,num_new_errors,expected_total_errors", + [ + (0, 0, 0), + (0, 1, 1), + (0, 2, 2), + (1, 0, 1), + (1, 1, 2), + (10, 2, 12), + (10, 20, 20), + ], + ) + def test__flush_internal_with_errors( + self, num_starting, num_new_errors, expected_total_errors + ): + """errors returned from _execute_mutate_rows should be added to internal exceptions""" + from google.cloud.bigtable.data import exceptions + + num_entries = 10 + expected_errors = [ + exceptions.FailedMutationEntryError(mock.Mock(), mock.Mock(), ValueError()) + ] * num_new_errors + with self._make_one() as instance: + instance._oldest_exceptions = [mock.Mock()] * num_starting + with mock.patch.object(instance, "_execute_mutate_rows") as execute_mock: + execute_mock.return_value = expected_errors + with mock.patch.object( + instance._flow_control, "add_to_flow" + ) as flow_mock: + + def gen(x): + yield x + + flow_mock.side_effect = lambda x: gen(x) + mutations = [self._make_mutation(count=1, size=1)] * num_entries + instance._flush_internal(mutations) + assert instance._entries_processed_since_last_raise == num_entries + assert execute_mock.call_count == 1 + assert flow_mock.call_count == 1 + found_exceptions = instance._oldest_exceptions + list( + instance._newest_exceptions + ) + assert len(found_exceptions) == expected_total_errors + for i in range(num_starting, expected_total_errors): + assert found_exceptions[i] == expected_errors[i - num_starting] + assert found_exceptions[i].index is None + instance._oldest_exceptions.clear() + instance._newest_exceptions.clear() + + def _mock_gapic_return(self, num=5): + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc import status_pb2 + + def gen(num): + for i in range(num): + entry = MutateRowsResponse.Entry( + index=i, status=status_pb2.Status(code=0) + ) + yield MutateRowsResponse(entries=[entry]) + + return gen(num) + + def test_timer_flush_end_to_end(self): + """Flush should automatically trigger after flush_interval""" + num_mutations = 10 + mutations = [self._make_mutation(count=2, size=2)] * num_mutations + with self._make_one(flush_interval=0.05) as instance: + instance._table.default_operation_timeout = 10 + instance._table.default_attempt_timeout = 9 + with mock.patch.object( + instance._table.client._gapic_client, "mutate_rows" + ) as gapic_mock: + gapic_mock.side_effect = ( + lambda *args, **kwargs: self._mock_gapic_return(num_mutations) + ) + for m in mutations: + instance.append(m) + assert instance._entries_processed_since_last_raise == 0 + CrossSync._Sync_Impl.sleep(0.1) + assert instance._entries_processed_since_last_raise == num_mutations + + def test__execute_mutate_rows(self): + with mock.patch.object( + CrossSync._Sync_Impl, "_MutateRowsOperation" + ) as mutate_rows: + mutate_rows.return_value = CrossSync._Sync_Impl.Mock() + start_operation = mutate_rows().start + table = mock.Mock() + table.table_name = "test-table" + table.app_profile_id = "test-app-profile" + 
table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + with self._make_one(table) as instance: + batch = [self._make_mutation()] + result = instance._execute_mutate_rows(batch) + assert start_operation.call_count == 1 + args, kwargs = mutate_rows.call_args + assert args[0] == table.client._gapic_client + assert args[1] == table + assert args[2] == batch + kwargs["operation_timeout"] == 17 + kwargs["attempt_timeout"] == 13 + assert result == [] + + def test__execute_mutate_rows_returns_errors(self): + """Errors from operation should be retruned as list""" + from google.cloud.bigtable.data.exceptions import ( + MutationsExceptionGroup, + FailedMutationEntryError, + ) + + with mock.patch.object( + CrossSync._Sync_Impl._MutateRowsOperation, "start" + ) as mutate_rows: + err1 = FailedMutationEntryError(0, mock.Mock(), RuntimeError("test error")) + err2 = FailedMutationEntryError(1, mock.Mock(), RuntimeError("test error")) + mutate_rows.side_effect = MutationsExceptionGroup([err1, err2], 10) + table = mock.Mock() + table.default_mutate_rows_operation_timeout = 17 + table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () + with self._make_one(table) as instance: + batch = [self._make_mutation()] + result = instance._execute_mutate_rows(batch) + assert len(result) == 2 + assert result[0] == err1 + assert result[1] == err2 + assert result[0].index is None + assert result[1].index is None + + def test__raise_exceptions(self): + """Raise exceptions and reset error state""" + from google.cloud.bigtable.data import exceptions + + expected_total = 1201 + expected_exceptions = [RuntimeError("mock")] * 3 + with self._make_one() as instance: + instance._oldest_exceptions = expected_exceptions + instance._entries_processed_since_last_raise = expected_total + try: + instance._raise_exceptions() + except exceptions.MutationsExceptionGroup as exc: + assert list(exc.exceptions) == expected_exceptions + assert str(expected_total) in str(exc) + assert instance._entries_processed_since_last_raise == 0 + instance._oldest_exceptions, instance._newest_exceptions = ([], []) + instance._raise_exceptions() + + def test___enter__(self): + """Should return self""" + with self._make_one() as instance: + assert instance.__enter__() == instance + + def test___exit__(self): + """aexit should call close""" + with self._make_one() as instance: + with mock.patch.object(instance, "close") as close_mock: + instance.__exit__(None, None, None) + assert close_mock.call_count == 1 + + def test_close(self): + """Should clean up all resources""" + with self._make_one() as instance: + with mock.patch.object(instance, "_schedule_flush") as flush_mock: + with mock.patch.object(instance, "_raise_exceptions") as raise_mock: + instance.close() + assert instance.closed is True + assert instance._flush_timer.done() is True + assert instance._flush_jobs == set() + assert flush_mock.call_count == 1 + assert raise_mock.call_count == 1 + + def test_close_w_exceptions(self): + """Raise exceptions on close""" + from google.cloud.bigtable.data import exceptions + + expected_total = 10 + expected_exceptions = [RuntimeError("mock")] + with self._make_one() as instance: + instance._oldest_exceptions = expected_exceptions + instance._entries_processed_since_last_raise = expected_total + try: + instance.close() + except exceptions.MutationsExceptionGroup as exc: + assert list(exc.exceptions) == expected_exceptions + assert 
str(expected_total) in str(exc) + assert instance._entries_processed_since_last_raise == 0 + instance._oldest_exceptions, instance._newest_exceptions = ([], []) + + def test__on_exit(self, recwarn): + """Should raise warnings if unflushed mutations exist""" + with self._make_one() as instance: + instance._on_exit() + assert len(recwarn) == 0 + num_left = 4 + instance._staged_entries = [mock.Mock()] * num_left + with pytest.warns(UserWarning) as w: + instance._on_exit() + assert len(w) == 1 + assert "unflushed mutations" in str(w[0].message).lower() + assert str(num_left) in str(w[0].message) + instance._closed.set() + instance._on_exit() + assert len(recwarn) == 0 + instance._staged_entries = [] + + def test_atexit_registration(self): + """Should run _on_exit on program termination""" + import atexit + + with mock.patch.object(atexit, "register") as register_mock: + assert register_mock.call_count == 0 + with self._make_one(): + assert register_mock.call_count == 1 + + def test_timeout_args_passed(self): + """batch_operation_timeout and batch_attempt_timeout should be used + in api calls""" + with mock.patch.object( + CrossSync._Sync_Impl, + "_MutateRowsOperation", + return_value=CrossSync._Sync_Impl.Mock(), + ) as mutate_rows: + expected_operation_timeout = 17 + expected_attempt_timeout = 13 + with self._make_one( + batch_operation_timeout=expected_operation_timeout, + batch_attempt_timeout=expected_attempt_timeout, + ) as instance: + assert instance._operation_timeout == expected_operation_timeout + assert instance._attempt_timeout == expected_attempt_timeout + instance._execute_mutate_rows([self._make_mutation()]) + assert mutate_rows.call_count == 1 + kwargs = mutate_rows.call_args[1] + assert kwargs["operation_timeout"] == expected_operation_timeout + assert kwargs["attempt_timeout"] == expected_attempt_timeout + + @pytest.mark.parametrize( + "limit,in_e,start_e,end_e", + [ + (10, 0, (10, 0), (10, 0)), + (1, 10, (0, 0), (1, 1)), + (10, 1, (0, 0), (1, 0)), + (10, 10, (0, 0), (10, 0)), + (10, 11, (0, 0), (10, 1)), + (3, 20, (0, 0), (3, 3)), + (10, 20, (0, 0), (10, 10)), + (10, 21, (0, 0), (10, 10)), + (2, 1, (2, 0), (2, 1)), + (2, 1, (1, 0), (2, 0)), + (2, 2, (1, 0), (2, 1)), + (3, 1, (3, 1), (3, 2)), + (3, 3, (3, 1), (3, 3)), + (1000, 5, (999, 0), (1000, 4)), + (1000, 5, (0, 0), (5, 0)), + (1000, 5, (1000, 0), (1000, 5)), + ], + ) + def test__add_exceptions(self, limit, in_e, start_e, end_e): + """Test that the _add_exceptions function properly updates the + _oldest_exceptions and _newest_exceptions lists + Args: + - limit: the _exception_list_limit representing the max size of either list + - in_e: size of list of exceptions to send to _add_exceptions + - start_e: a tuple of ints representing the initial sizes of _oldest_exceptions and _newest_exceptions + - end_e: a tuple of ints representing the expected sizes of _oldest_exceptions and _newest_exceptions + """ + from collections import deque + + input_list = [RuntimeError(f"mock {i}") for i in range(in_e)] + mock_batcher = mock.Mock() + mock_batcher._oldest_exceptions = [ + RuntimeError(f"starting mock {i}") for i in range(start_e[0]) + ] + mock_batcher._newest_exceptions = deque( + [RuntimeError(f"starting mock {i}") for i in range(start_e[1])], + maxlen=limit, + ) + mock_batcher._exception_list_limit = limit + mock_batcher._exceptions_since_last_raise = 0 + self._get_target_class()._add_exceptions(mock_batcher, input_list) + assert len(mock_batcher._oldest_exceptions) == end_e[0] + assert len(mock_batcher._newest_exceptions) == 
end_e[1] + assert mock_batcher._exceptions_since_last_raise == in_e + oldest_list_diff = end_e[0] - start_e[0] + newest_list_diff = min(max(in_e - oldest_list_diff, 0), limit) + for i in range(oldest_list_diff): + assert mock_batcher._oldest_exceptions[i + start_e[0]] == input_list[i] + for i in range(1, newest_list_diff + 1): + assert mock_batcher._newest_exceptions[-i] == input_list[-i] + + @pytest.mark.parametrize( + "input_retryables,expected_retryables", + [ + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), + ], + ) + def test_customizable_retryable_errors(self, input_retryables, expected_retryables): + """Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer.""" + with mock.patch.object( + google.api_core.retry, "if_exception_type" + ) as predicate_builder_mock: + with mock.patch.object( + CrossSync._Sync_Impl, "retry_target" + ) as retry_fn_mock: + table = None + with mock.patch("asyncio.create_task"): + table = CrossSync._Sync_Impl.Table(mock.Mock(), "instance", "table") + with self._make_one( + table, batch_retryable_errors=input_retryables + ) as instance: + assert instance._retryable_errors == expected_retryables + expected_predicate = expected_retryables.__contains__ + predicate_builder_mock.return_value = expected_predicate + retry_fn_mock.side_effect = RuntimeError("stop early") + mutation = self._make_mutation(count=1, size=1) + instance._execute_mutate_rows([mutation]) + predicate_builder_mock.assert_called_once_with( + *expected_retryables, _MutateRowsIncomplete + ) + retry_call_args = retry_fn_mock.call_args_list[0].args + assert retry_call_args[1] is expected_predicate diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync/test_read_rows_acceptance.py new file mode 100644 index 000000000..dcdd7d66c --- /dev/null +++ b/tests/unit/data/_sync/test_read_rows_acceptance.py @@ -0,0 +1,327 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by CrossSync. Do not edit manually. 
+from __future__ import annotations +import os +import warnings +import pytest +import mock +from itertools import zip_longest +from google.cloud.bigtable_v2 import ReadRowsResponse +from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.row import Row +from ...v2_client.test_row_merger import ReadRowsTest, TestFile +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + + +class TestReadRowsAcceptance: + @staticmethod + def _get_operation_class(): + return CrossSync._Sync_Impl._ReadRowsOperation + + @staticmethod + def _get_client_class(): + return CrossSync._Sync_Impl.DataClient + + def parse_readrows_acceptance_tests(): + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "../read-rows-acceptance-test.json") + with open(filename) as json_file: + test_json = TestFile.from_json(json_file.read()) + return test_json.read_rows_tests + + @staticmethod + def extract_results_from_row(row: Row): + results = [] + for family, col, cells in row.items(): + for cell in cells: + results.append( + ReadRowsTest.Result( + row_key=row.row_key, + family_name=family, + qualifier=col, + timestamp_micros=cell.timestamp_ns // 1000, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + ) + return results + + @staticmethod + def _coro_wrapper(stream): + return stream + + def _process_chunks(self, *chunks): + def _row_stream(): + yield ReadRowsResponse(chunks=chunks) + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + results = [] + for row in merger: + results.append(row) + return results + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + def test_row_merger_scenario(self, test_case: ReadRowsTest): + def _scenerio_stream(): + for chunk in test_case.chunks: + yield ReadRowsResponse(chunks=[chunk]) + + try: + results = [] + instance = mock.Mock() + instance._last_yielded_row_key = None + instance._remaining_count = None + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_scenerio_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + for row in merger: + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + @pytest.mark.parametrize( + "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description + ) + def test_read_rows_scenario(self, test_case: ReadRowsTest): + def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): + from google.cloud.bigtable_v2 import ReadRowsResponse + + class mock_stream: + def __init__(self, chunk_list): + self.chunk_list = chunk_list + self.idx = -1 + + def __aiter__(self): + return self + + def __iter__(self): + return self + + def __anext__(self): + self.idx += 1 + if len(self.chunk_list) > self.idx: + chunk = self.chunk_list[self.idx] + return ReadRowsResponse(chunks=[chunk]) + raise CrossSync._Sync_Impl.StopIteration + + def __next__(self): + return 
self.__anext__() + + def cancel(self): + pass + + return mock_stream(chunk_list) + + with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + client = self._get_client_class()() + try: + table = client.get_table("instance", "table") + results = [] + with mock.patch.object( + table.client._gapic_client, "read_rows" + ) as read_rows: + read_rows.return_value = _make_gapic_stream(test_case.chunks) + for row in table.read_rows_stream(query={}): + for cell in row: + cell_result = ReadRowsTest.Result( + row_key=cell.row_key, + family_name=cell.family, + qualifier=cell.qualifier, + timestamp_micros=cell.timestamp_micros, + value=cell.value, + label=cell.labels[0] if cell.labels else "", + ) + results.append(cell_result) + except InvalidChunk: + results.append(ReadRowsTest.Result(error=True)) + finally: + client.close() + for expected, actual in zip_longest(test_case.results, results): + assert actual == expected + + def test_out_of_order_rows(self): + def _row_stream(): + yield ReadRowsResponse(last_scanned_row_key=b"a") + + instance = mock.Mock() + instance._remaining_count = None + instance._last_yielded_row_key = b"b" + chunker = self._get_operation_class().chunk_stream( + instance, self._coro_wrapper(_row_stream()) + ) + merger = self._get_operation_class().merge_rows(chunker) + with pytest.raises(InvalidChunk): + for _ in merger: + pass + + def test_bare_reset(self): + first_chunk = ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk( + row_key=b"a", family_name="f", qualifier=b"q", value=b"v" + ) + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, row_key=b"a") + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, family_name="f") + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, qualifier=b"q") + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, timestamp_micros=1000) + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, labels=["a"]) + ), + ) + with pytest.raises(InvalidChunk): + self._process_chunks( + first_chunk, + ReadRowsResponse.CellChunk( + ReadRowsResponse.CellChunk(reset_row=True, value=b"v") + ), + ) + + def test_missing_family(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + qualifier=b"q", + timestamp_micros=1000, + value=b"v", + commit_row=True, + ) + ) + + def test_mid_cell_row_key_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(row_key=b"b", value=b"v", commit_row=True), + ) + + def test_mid_cell_family_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + family_name="f2", value=b"v", commit_row=True + ), + ) + + 
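In these mid-cell tests, `value_size` on the opening chunk signals that the cell's value is split across multiple chunks, so the follow-up chunk may only carry additional value bytes. A legal continuation would look roughly like the sketch below (variable names and values are illustrative); changing the row key, family, qualifier, timestamp, or labels in that second chunk is what the surrounding tests expect to raise `InvalidChunk`.

```python
from google.cloud.bigtable_v2 import ReadRowsResponse

# opening chunk of a split-value cell (value_size marks the value as partial)
opening = ReadRowsResponse.CellChunk(
    row_key=b"a", family_name="f", qualifier=b"q",
    timestamp_micros=1000, value_size=2, value=b"v",
)
# a valid continuation carries only more value bytes and the final commit flag
continuation = ReadRowsResponse.CellChunk(value=b"2", commit_row=True)
```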
def test_mid_cell_qualifier_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + qualifier=b"q2", value=b"v", commit_row=True + ), + ) + + def test_mid_cell_timestamp_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk( + timestamp_micros=2000, value=b"v", commit_row=True + ), + ) + + def test_mid_cell_labels_change(self): + with pytest.raises(InvalidChunk): + self._process_chunks( + ReadRowsResponse.CellChunk( + row_key=b"a", + family_name="f", + qualifier=b"q", + timestamp_micros=1000, + value_size=2, + value=b"v", + ), + ReadRowsResponse.CellChunk(labels=["b"], value=b"v", commit_row=True), + ) From 3bb6c6f1570da4e8e741ef13f6b4c2973d58df51 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 15:03:42 -0600 Subject: [PATCH 204/360] added files to __init__.py --- google/cloud/bigtable/data/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index 66fe3479b..c660a5130 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -17,8 +17,10 @@ from google.cloud.bigtable.data._async.client import BigtableDataClientAsync from google.cloud.bigtable.data._async.client import TableAsync - from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync +from google.cloud.bigtable.data._sync.client import BigtableDataClient +from google.cloud.bigtable.data._sync.client import Table +from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.read_rows_query import RowRange @@ -51,6 +53,9 @@ "BigtableDataClientAsync", "TableAsync", "MutationsBatcherAsync", + "BigtableDataClient", + "Table", + "MutationsBatcher", "RowKeySamples", "ReadRowsQuery", "RowRange", From 4a041c818ac9364210841d04ecd3fcc56e7c8eba Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 15:03:56 -0600 Subject: [PATCH 205/360] added cross_sync to path for transformers --- .cross_sync/transformers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index f630e29eb..6b7aa26b5 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -15,8 +15,10 @@ import ast -from google.cloud.bigtable.data._sync.cross_sync import CrossSync -from google.cloud.bigtable.data._sync.cross_sync_decorators import AstDecorator, ExportSync +import sys +# add cross_sync to path +sys.path.append("google/cloud/bigtable/data/_sync") +from cross_sync_decorators import AstDecorator, ExportSync from generate import CrossSyncOutputFile From cbc36dd603b7020249d533e7670aa081b9c0ea4b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 15:17:46 -0600 Subject: [PATCH 206/360] support table in helpers --- google/cloud/bigtable/data/_helpers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index a8113cc4a..3ed6422ab 100644 --- 
a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -29,6 +29,7 @@ if TYPE_CHECKING: import grpc from google.cloud.bigtable.data import TableAsync + from google.cloud.bigtable.data import Table """ Helper functions used in various places in the library. @@ -137,7 +138,7 @@ def _retry_exception_factory( def _get_timeouts( operation: float | TABLE_DEFAULT, attempt: float | None | TABLE_DEFAULT, - table: "TableAsync", + table: "TableAsync" | "Table", ) -> tuple[float, float]: """ Convert passed in timeout values to floats, using table defaults if necessary. @@ -208,7 +209,7 @@ def _validate_timeouts( def _get_retryable_errors( call_codes: Sequence["grpc.StatusCode" | int | type[Exception]] | TABLE_DEFAULT, - table: "TableAsync", + table: "TableAsync" | "Table", ) -> list[type[Exception]]: """ Convert passed in retryable error codes to a list of exception types. From 03b60ed097b238ffba353040b2c2ba7cd60af857 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 15:32:56 -0600 Subject: [PATCH 207/360] removed quoted cross sync classes --- google/cloud/bigtable/data/_async/_mutate_rows.py | 4 ++-- google/cloud/bigtable/data/_async/_read_rows.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index f1c016e4c..7afc65816 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -71,8 +71,8 @@ class _MutateRowsOperationAsync: @CrossSync.convert def __init__( self, - gapic_client: "CrossSync.GapicClient", - table: "CrossSync.Table", + gapic_client: CrossSync.GapicClient, + table: CrossSync.Table, mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 2fe48e9e9..cb139a829 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -75,7 +75,7 @@ class _ReadRowsOperationAsync: def __init__( self, query: ReadRowsQuery, - table: "CrossSync.Table", + table: CrossSync.Table, operation_timeout: float, attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), From 3762a036d6ff9bf763b3fd7cf4819cd5b3a2c2a3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 15:59:26 -0600 Subject: [PATCH 208/360] fixed mypy issues --- .../data/_sync/cross_sync_decorators.py | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py index 2ca763d95..4420345b3 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync_decorators.py @@ -98,7 +98,7 @@ def sync_ast_transform( return wrapped_node @classmethod - def get_for_node(cls, node: ast.Call) -> "AstDecorator": + def get_for_node(cls, node: ast.Call | ast.Attribute | ast.Name) -> "AstDecorator": """ Build an AstDecorator instance from an ast decorator node @@ -116,9 +116,14 @@ def get_for_node(cls, node: ast.Call) -> "AstDecorator": """ import ast - if "CrossSync" in ast.dump(node): - decorator_name = node.func.attr if hasattr(node, "func") else node.attr - formatted_name = decorator_name.replace("_", "").lower() + # expect decorators in format @CrossSync. + # (i.e. 
should be an ast.Call or an ast.Attribute) + root_attr = node.func if isinstance(node, ast.Call) else node + if not isinstance(root_attr, ast.Attribute): + raise ValueError("Unexpected decorator format") + # extract the module and decorator names + if "CrossSync" in ast.dump(root_attr): + decorator_name = root_attr.attr got_kwargs = ( {kw.arg: cls._convert_ast_to_py(kw.value) for kw in node.keywords} if hasattr(node, "keywords") @@ -129,14 +134,17 @@ def get_for_node(cls, node: ast.Call) -> "AstDecorator": if hasattr(node, "args") else [] ) + # convert to standardized representation + formatted_name = decorator_name.replace("_", "").lower() for subclass in cls.__subclasses__(): if subclass.__name__.lower() == formatted_name: return subclass(*got_args, **got_kwargs) raise ValueError(f"Unknown decorator encountered: {decorator_name}") - raise ValueError("Not a CrossSync decorator") + else: + raise ValueError("Not a CrossSync decorator") @classmethod - def _convert_ast_to_py(cls, ast_node: ast.expr|None) -> Any: + def _convert_ast_to_py(cls, ast_node: ast.expr | None) -> Any: """ Helper to convert ast primitives to python primitives. Used when unwrapping arguments """ @@ -313,7 +321,8 @@ class Pytest(AstDecorator): """ Used in place of pytest.mark.asyncio to mark tests - Will be stripped from sync output + When generating sync version, also runs rm_aio to remove async keywords from + entire test function """ def async_decorator(self): @@ -343,7 +352,7 @@ def __init__(self, *args, **kwargs): self._kwargs = kwargs def async_decorator(self): - import pytest_asyncio + import pytest_asyncio # type: ignore return lambda f: pytest_asyncio.fixture(*self._args, **self._kwargs)(f) From 274bd36a50c2b3d66223d3d4da3c3572597220e1 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 16:10:06 -0600 Subject: [PATCH 209/360] moved cross_sync into own directory --- .cross_sync/transformers.py | 17 ++++++++-------- google/cloud/bigtable/data/_sync/__init__.py | 20 +++++++++++++++++++ .../_decorators.py} | 0 .../_sync/{ => _cross_sync}/cross_sync.py | 2 +- 4 files changed, 30 insertions(+), 9 deletions(-) create mode 100644 google/cloud/bigtable/data/_sync/__init__.py rename google/cloud/bigtable/data/_sync/{cross_sync_decorators.py => _cross_sync/_decorators.py} (100%) rename google/cloud/bigtable/data/_sync/{ => _cross_sync}/cross_sync.py (99%) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 6b48421d3..117a25f08 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -15,8 +15,10 @@ import ast -from google.cloud.bigtable.data._sync.cross_sync import CrossSync -from google.cloud.bigtable.data._sync.cross_sync_decorators import AstDecorator, ExportSync +import sys +# add cross_sync to path +sys.path.append("google/cloud/bigtable/data/_sync/_cross_sync") +from _decorators import AstDecorator, ExportSync from generate import CrossSyncOutputFile @@ -148,18 +150,15 @@ def visit_AsyncWith(self, node): Async with statements are not fully wrapped by calls """ found_rmaio = False - new_items = [] for item in node.items: if isinstance(item.context_expr, ast.Call) and isinstance(item.context_expr.func, ast.Attribute) and isinstance(item.context_expr.func.value, ast.Name) and \ item.context_expr.func.attr == "rm_aio" and "CrossSync" in item.context_expr.func.value.id: found_rmaio = True - new_items.append(item.context_expr.args[0]) - else: - new_items.append(item) + break if found_rmaio: new_node = ast.copy_location( ast.With( - [self.generic_visit(item) for 
item in new_items], + [self.generic_visit(item) for item in node.items], [self.generic_visit(stmt) for stmt in node.body], ), node, @@ -177,7 +176,7 @@ def visit_AsyncFor(self, node): return ast.copy_location( ast.For( self.visit(node.target), - self.visit(node.iter.args[0]), + self.visit(it), [self.visit(stmt) for stmt in node.body], [self.visit(stmt) for stmt in node.orelse], ), @@ -204,6 +203,8 @@ def visit_AsyncFunctionDef(self, node): node = handler.sync_ast_transform(node, globals()) if node is None: return None + # recurse to any nested functions + node = self.generic_visit(node) except ValueError: # keep unknown decorators node.decorator_list.append(decorator) diff --git a/google/cloud/bigtable/data/_sync/__init__.py b/google/cloud/bigtable/data/_sync/__init__.py new file mode 100644 index 000000000..476b531db --- /dev/null +++ b/google/cloud/bigtable/data/_sync/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.bigtable.data._sync._cross_sync import CrossSync + + +__all__ = [ + "CrossSync", +] diff --git a/google/cloud/bigtable/data/_sync/cross_sync_decorators.py b/google/cloud/bigtable/data/_sync/_cross_sync/_decorators.py similarity index 100% rename from google/cloud/bigtable/data/_sync/cross_sync_decorators.py rename to google/cloud/bigtable/data/_sync/_cross_sync/_decorators.py diff --git a/google/cloud/bigtable/data/_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/_cross_sync/cross_sync.py similarity index 99% rename from google/cloud/bigtable/data/_sync/cross_sync.py rename to google/cloud/bigtable/data/_sync/_cross_sync/cross_sync.py index 70a3950d0..89f6e5ec4 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/_cross_sync/cross_sync.py @@ -35,7 +35,7 @@ import queue import threading import time -from .cross_sync_decorators import ( +from ._decorators import ( ExportSync, Convert, DropMethod, From 56116140c9e091d025137210e6080fff55a171ca Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 16:37:08 -0600 Subject: [PATCH 210/360] added docstrings --- .cross_sync/generate.py | 7 +++++ .cross_sync/transformers.py | 13 ++++++++ .../data/_sync/_cross_sync/_decorators.py | 6 +++- .../data/_sync/_cross_sync/cross_sync.py | 31 +++++++++++++++++++ 4 files changed, 56 insertions(+), 1 deletion(-) diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index c92d700a2..2523e4d0d 100644 --- a/.cross_sync/generate.py +++ b/.cross_sync/generate.py @@ -15,6 +15,13 @@ from typing import Sequence import ast from dataclasses import dataclass, field +""" +Entrypoint for initiating an async -> sync conversion using CrossSync + +Finds all python files rooted in a given directory, and uses +transformers.CrossSyncClassDecoratorHandler to handle any CrossSync class +decorators found in the files. 
+""" @dataclass diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 117a25f08..6f5047d50 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -11,6 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Provides a set of ast.NodeTransformer subclasses that are composed to generate +async code into sync code. + +At a high level: +- The main entrypoint is CrossSyncClassDecoratorHandler, which is used to find classes +annotated with @CrossSync.export_sync. +- SymbolReplacer is used to swap out CrossSync.X with CrossSync._Sync_Impl.X +- RmAioFunctions is then called on the class, to strip out asyncio keywords +marked with CrossSync.rm_aio (using AsyncToSync to handle the actual transformation) +- Finally, CrossSyncMethodDecoratorHandler is called to find methods annotated +with AstDecorators, and call decorator.sync_ast_transform on each one to fully transform the class. +""" from __future__ import annotations import ast diff --git a/google/cloud/bigtable/data/_sync/_cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/_cross_sync/_decorators.py index 4420345b3..e6728f19d 100644 --- a/google/cloud/bigtable/data/_sync/_cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/_cross_sync/_decorators.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Contains a set of AstDecorator classes, which define the behavior of CrossSync decorators. +Each AstDecorator class is used through @CrossSync. +""" from __future__ import annotations from typing import TYPE_CHECKING @@ -23,7 +27,7 @@ class AstDecorator: """ Helper class for CrossSync decorators used for guiding ast transformations. - CrossSync decorations are accessed in two ways: + AstDecorators are accessed in two ways: 1. The decorations are used directly as method decorations in the async client, wrapping existing classes and methods 2. The decorations are read back when processing the AST transformations when diff --git a/google/cloud/bigtable/data/_sync/_cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/_cross_sync/cross_sync.py index 89f6e5ec4..0ad569d33 100644 --- a/google/cloud/bigtable/data/_sync/_cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/_cross_sync/cross_sync.py @@ -12,6 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. # +""" +CrossSync provides a toolset for sharing logic between async and sync codebases, including: +- A set of decorators for annotating async classes and functions + (@CrossSync.export_sync, @CrossSync.convert, @CrossSync.drop_method, ...) +- A set of wrappers to wrap common objects and types that have corresponding async and sync implementations + (CrossSync.Queue, CrossSync.Condition, CrossSync.Future, ...) +- A set of function implementations for common async operations that can be used in both async and sync codebases + (CrossSync.gather_partials, CrossSync.wait, CrossSync.condition_wait, ...) +- CrossSync.rm_aio(), which is used to annotate regions of the code containing async keywords to strip + +A separate module will use CrossSync annotations to generate a corresponding sync +class based on a decorated async class. 
+ +Usage Example: +```python +@CrossSync.export_sync(path="path/to/sync_module.py") + + @CrossSync.convert + async def async_func(self, arg: int) -> int: + await CrossSync.sleep(1) + return arg +``` +""" + + from __future__ import annotations from typing import ( @@ -223,6 +248,12 @@ def verify_async_event_loop() -> None: @staticmethod def rm_aio(statement: Any) -> Any: + """ + Used to annotate regions of the code containing async keywords to strip + + All async keywords inside an rm_aio call are removed, along with + `async with` and `async for` statements containing CrossSync.rm_aio() in the body + """ return statement class _Sync_Impl: From ad95748bc78b3adb8c930cbafe155beddac86943 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 17:11:23 -0600 Subject: [PATCH 211/360] added README --- .cross_sync/README.md | 72 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 .cross_sync/README.md diff --git a/.cross_sync/README.md b/.cross_sync/README.md new file mode 100644 index 000000000..d8aff8731 --- /dev/null +++ b/.cross_sync/README.md @@ -0,0 +1,72 @@ +# CrossSync + +CrossSync provides a simple way to share logic between async and sync code. +It is made up of a small library that provides: +1. a set of shims that provide a shared sync/async API surface +2. annotations that are used to guide generation of a sync version from an async class + +Using CrossSync, the async code is treated as the source of truth, and sync code is generated from it. + +## Usage + +### CrossSync Shims + +Many Asyncio components have direct, 1:1 threaded counterparts for use in non-asyncio code. CrossSync +provides a compatibility layer that works with both + +| CrossSync | Asyncio Version | Sync Version | +| CrossSync.Queue | asyncio.Queue | queue.Queue | +| CrossSync.Condition | asyncio.Condition | threading.Condition | +| CrossSync.Future | asyncio.Future | Concurrent.futures.Future | +| CrossSync.Task | asyncio.Task | Concurrent.futures.Future | +| CrossSync.Event | asyncio.Event | threading.Event | +| CrossSync.Semaphore | asyncio.Semaphore | threading.Semaphore | +| CrossSync.Awaitable | typing.Awaitable | typing.Union (no-op type) | +| CrossSync.Iterable | typing.AsyncIterable | typing.Iterable | +| CrossSync.Iterator | typing.AsyncIterator | typing.Iterator | +| CrossSync.Generator | typing.AsyncGenerator | typing.Generator | +| CrossSync.Retry | google.api_core.retry.AsyncRetry | google.api_core.retry.Retry | +| CrossSync.StopIteration | StopAsyncIteration | StopIteration | +| CrossSync.Mock | unittest.mock.AsyncMock | unittest.mock.Mock | + +Custom aliases can be added using `CrossSync.add_mapping(class, name)` + +Additionally, CrossSync provides method implementations that work equivalently in async and sync code: +- `CrossSync.sleep()` +- `CrossSync.gather_partials()` +- `CrossSync.wait()` +- `CrossSync.condition_wait()` +- `CrossSync,event_wait()` +- `CrossSync.create_task()` +- `CrossSync.retry_target()`` +- `CrossSync.retry_target_stream()` + +### Annotations + +CrossSync provides a set of annotations to mark up async classes, to guide the generation of sync code. 
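For orientation, a minimal annotated class might look like the sketch below, adapted from the CrossSync docstring's usage example (the class name, method body, and output path are illustrative, and CrossSync is assumed to be imported from the library's cross_sync module). Each annotation is described in the list that follows.

```python
@CrossSync.export_sync(path="path/to/sync_module.py")
class MyClassAsync:
    @CrossSync.convert
    async def async_func(self, arg: int) -> int:
        # CrossSync.sleep maps to asyncio.sleep here, time.sleep in the generated class
        await CrossSync.sleep(1)
        return arg

    @CrossSync.drop_method
    async def asyncio_only_helper(self) -> None:
        # omitted entirely from the generated sync class
        ...
```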
+ +- `@CrossSync.export_sync` + - marks classes for conversion, along with an output file path + - if add_mapping is included, the async and sync classes can be accessed using a shared CrossSync.X alias +- `@CrossSync.convert` + - marks async functions for conversion +- `@CrossSync.drop_method` + - marks functions that should not be included in sync output +- `@CrossSync.pytest` + - marks test functions. Test functions automatically have all async keywords stripped (i.e., rm_aio is unneeded) +- `CrossSync.add_mapping` + - manually registers a new CrossSync.X alias, for custom types +- `CrossSync.rm_aio` + - Marks regions of the code that include asyncio keywords that should be stripped during generation + +### Code Generation + +Generation can be initiated using `python .cross_sync/generate.py .` +from the root of the project. This will find all classes with the `@CrossSync.export_sync` annotation +in both `/google` and `/tests` directories, and save them to their specified output paths + +## Architecture + +CrossSync is made up of two parts: +- the runtime shims and annotations live in `/google/cloud/bigtable/_sync/_cross_sync` +- the code generation logic lives in `/.cross_sync/` in the repo root From cad416dc84924705b7a628e50657349fb92a7bdc Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 22 Jul 2024 16:14:17 -0700 Subject: [PATCH 212/360] fixed README formatting --- .cross_sync/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.cross_sync/README.md b/.cross_sync/README.md index d8aff8731..563fccb3b 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -15,6 +15,7 @@ Many Asyncio components have direct, 1:1 threaded counterparts for use in non-as provides a compatibility layer that works with both | CrossSync | Asyncio Version | Sync Version | +| --- | --- | --- | | CrossSync.Queue | asyncio.Queue | queue.Queue | | CrossSync.Condition | asyncio.Condition | threading.Condition | | CrossSync.Future | asyncio.Future | Concurrent.futures.Future | @@ -38,7 +39,7 @@ Additionally, CrossSync provides method implementations that work equivalently i - `CrossSync.condition_wait()` - `CrossSync,event_wait()` - `CrossSync.create_task()` -- `CrossSync.retry_target()`` +- `CrossSync.retry_target()` - `CrossSync.retry_target_stream()` ### Annotations From 9a1965853efa923187bb678468a514bf205ef477 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 23 Jul 2024 08:43:54 -0600 Subject: [PATCH 213/360] added rm_aio to pytest and convert decorators --- google/cloud/bigtable/data/_sync/__init__.py | 2 +- .../data/_sync/_cross_sync/_decorators.py | 38 ++++++++++++------- .../data/_sync/_cross_sync/cross_sync.py | 2 +- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/__init__.py b/google/cloud/bigtable/data/_sync/__init__.py index 476b531db..99779a72f 100644 --- a/google/cloud/bigtable/data/_sync/__init__.py +++ b/google/cloud/bigtable/data/_sync/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from google.cloud.bigtable.data._sync._cross_sync import CrossSync +from google.cloud.bigtable.data._sync._cross_sync.cross_sync import CrossSync __all__ = [ diff --git a/google/cloud/bigtable/data/_sync/_cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/_cross_sync/_decorators.py index e6728f19d..2788ffec4 100644 --- a/google/cloud/bigtable/data/_sync/_cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/_cross_sync/_decorators.py @@ -182,14 +182,12 @@ class ExportSync(AstDecorator): def __init__( self, - path: str, # path to output the generated sync class + path: str, *, - replace_symbols: dict[str, str] - | None = None, # replace symbols in the generated sync class - mypy_ignore: Sequence[str] = (), # set of mypy errors to ignore - include_file_imports: bool = True, # include imports from the file in the generated sync class - add_mapping_for_name: str - | None = None, # add a new attribute to CrossSync with the given name + replace_symbols: dict[str, str] | None = None, + mypy_ignore: Sequence[str] = (), + include_file_imports: bool = True, + add_mapping_for_name: str | None = None, ): self.path = path self.replace_symbols = replace_symbols @@ -269,18 +267,21 @@ class Convert(AstDecorator): Args: sync_name: use a new name for the sync method replace_symbols: a dict of symbols and replacements to use when generating sync method + rm_aio: if True, automatically strip all asyncio keywords from method. If False, + only the signature `async def` is stripped. Other keywords must be wrapped in + CrossSync.rm_aio() calls to be removed. """ def __init__( self, *, - sync_name: str | None = None, # use a new name for the sync method - replace_symbols: dict[ - str, str - ] = {}, # replace symbols in the generated sync method + sync_name: str | None = None, + replace_symbols: dict[str, str] | None = None, + rm_aio: bool = False, ): self.sync_name = sync_name self.replace_symbols = replace_symbols + self.rm_aio = rm_aio def sync_ast_transform(self, wrapped_node, transformers_globals): """ @@ -302,6 +303,9 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): # update name if specified if self.sync_name: wrapped_node.name = self.sync_name + # strip async keywords if specified + if self.rm_aio: + wrapped_node = transformers_globals["AsyncToSync"]().visit(wrapped_node) # update arbitrary symbols if specified if self.replace_symbols: replacer = transformers_globals["SymbolReplacer"] @@ -327,8 +331,15 @@ class Pytest(AstDecorator): When generating sync version, also runs rm_aio to remove async keywords from entire test function + + Args: + rm_aio: if True, automatically strip all asyncio keywords from test code. + Defaults to True, to simplify test code generation. 
""" + def __init__(self, rm_aio=True): + self.rm_aio = rm_aio + def async_decorator(self): import pytest @@ -338,8 +349,9 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): """ convert async to sync """ - converted = transformers_globals["AsyncToSync"]().visit(wrapped_node) - return converted + if self.rm_aio: + wrapped_node = transformers_globals["AsyncToSync"]().visit(wrapped_node) + return wrapped_node class PytestFixture(AstDecorator): diff --git a/google/cloud/bigtable/data/_sync/_cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/_cross_sync/cross_sync.py index 0ad569d33..2fee0e536 100644 --- a/google/cloud/bigtable/data/_sync/_cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/_cross_sync/cross_sync.py @@ -14,7 +14,7 @@ # """ CrossSync provides a toolset for sharing logic between async and sync codebases, including: -- A set of decorators for annotating async classes and functions +- A set of decorators for annotating async classes and functions (@CrossSync.export_sync, @CrossSync.convert, @CrossSync.drop_method, ...) - A set of wrappers to wrap common objects and types that have corresponding async and sync implementations (CrossSync.Queue, CrossSync.Condition, CrossSync.Future, ...) From b64c0c79758b96cbe3243bffc765aa62f49cbc36 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 23 Jul 2024 09:03:29 -0600 Subject: [PATCH 214/360] create paths when writing sync outputs --- .cross_sync/generate.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index 2523e4d0d..d93838e59 100644 --- a/.cross_sync/generate.py +++ b/.cross_sync/generate.py @@ -96,6 +96,9 @@ def render(self, with_black=True, save_to_disk=False) -> str: mode=black.FileMode(), ) if save_to_disk: + # create parent paths if needed + import os + os.makedirs(os.path.dirname(self.file_path), exist_ok=True) with open(self.file_path, "w") as f: f.write(full_str) return full_str From 223f3371d2918fb70485961243e0c95d7dbb79f4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 23 Jul 2024 09:05:57 -0600 Subject: [PATCH 215/360] moved files --- .cross_sync/transformers.py | 2 +- google/cloud/bigtable/data/_sync/{ => cross_sync}/__init__.py | 2 +- .../data/_sync/{_cross_sync => cross_sync}/_decorators.py | 0 .../data/_sync/{_cross_sync => cross_sync}/cross_sync.py | 0 4 files changed, 2 insertions(+), 2 deletions(-) rename google/cloud/bigtable/data/_sync/{ => cross_sync}/__init__.py (88%) rename google/cloud/bigtable/data/_sync/{_cross_sync => cross_sync}/_decorators.py (100%) rename google/cloud/bigtable/data/_sync/{_cross_sync => cross_sync}/cross_sync.py (100%) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 6f5047d50..76439b80e 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -30,7 +30,7 @@ import sys # add cross_sync to path -sys.path.append("google/cloud/bigtable/data/_sync/_cross_sync") +sys.path.append("google/cloud/bigtable/data/_sync/cross_sync") from _decorators import AstDecorator, ExportSync from generate import CrossSyncOutputFile diff --git a/google/cloud/bigtable/data/_sync/__init__.py b/google/cloud/bigtable/data/_sync/cross_sync/__init__.py similarity index 88% rename from google/cloud/bigtable/data/_sync/__init__.py rename to google/cloud/bigtable/data/_sync/cross_sync/__init__.py index 99779a72f..77a9ddae9 100644 --- a/google/cloud/bigtable/data/_sync/__init__.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/__init__.py @@ -12,7 +12,7 @@ # See the 
License for the specific language governing permissions and # limitations under the License. -from google.cloud.bigtable.data._sync._cross_sync.cross_sync import CrossSync +from .cross_sync import CrossSync __all__ = [ diff --git a/google/cloud/bigtable/data/_sync/_cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py similarity index 100% rename from google/cloud/bigtable/data/_sync/_cross_sync/_decorators.py rename to google/cloud/bigtable/data/_sync/cross_sync/_decorators.py diff --git a/google/cloud/bigtable/data/_sync/_cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py similarity index 100% rename from google/cloud/bigtable/data/_sync/_cross_sync/cross_sync.py rename to google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py From 3600a63886db1eb2021ec23dfc36deb45b36d454 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 24 Jul 2024 14:56:47 -0600 Subject: [PATCH 216/360] moved add_mapping into metaclass --- .../data/_sync/cross_sync/_mapping_meta.py | 64 +++++++++++++++++++ .../data/_sync/cross_sync/cross_sync.py | 47 +------------- 2 files changed, 67 insertions(+), 44 deletions(-) create mode 100644 google/cloud/bigtable/data/_sync/cross_sync/_mapping_meta.py diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_mapping_meta.py b/google/cloud/bigtable/data/_sync/cross_sync/_mapping_meta.py new file mode 100644 index 000000000..5312708cc --- /dev/null +++ b/google/cloud/bigtable/data/_sync/cross_sync/_mapping_meta.py @@ -0,0 +1,64 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations +from typing import Any + + +class MappingMeta(type): + """ + Metaclass to provide add_mapping functionality, allowing users to add + custom attributes to derived classes at runtime. 
+ + Using a metaclass allows us to share functionality between CrossSync + and CrossSync._Sync_Impl, and it works better with mypy checks than + monkypatching + """ + + # list of attributes that can be added to the derived class at runtime + _runtime_replacements: dict[tuple[MappingMeta, str], Any] = {} + + def add_mapping(cls: MappingMeta, name: str, value: Any): + """ + Add a new attribute to the class, for replacing library-level symbols + + Raises: + - AttributeError if the attribute already exists with a different value + """ + key = (cls, name) + old_value = cls._runtime_replacements.get(key) + if old_value is None: + cls._runtime_replacements[key] = value + elif old_value != value: + raise AttributeError(f"Conflicting assignments for CrossSync.{name}") + + def add_mapping_decorator(cls: MappingMeta, name: str): + """ + Exposes add_mapping as a class decorator + """ + + def decorator(wrapped_cls): + cls.add_mapping(name, wrapped_cls) + return wrapped_cls + + return decorator + + def __getattr__(cls: MappingMeta, name: str): + """ + Retrieve custom attributes + """ + key = (cls, name) + found = cls._runtime_replacements.get(key) + if found is not None: + return found + raise AttributeError(f"CrossSync has no attribute {name}") diff --git a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py index 2fee0e536..04ac79c73 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py @@ -36,7 +36,6 @@ async def async_func(self, arg: int) -> int: ``` """ - from __future__ import annotations from typing import ( @@ -67,6 +66,7 @@ async def async_func(self, arg: int) -> int: Pytest, PytestFixture, ) +from ._mapping_meta import MappingMeta if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -74,7 +74,7 @@ async def async_func(self, arg: int) -> int: T = TypeVar("T") -class CrossSync: +class CrossSync(metaclass=MappingMeta): # support CrossSync.is_async to check if the current environment is async is_async = True @@ -105,23 +105,6 @@ class CrossSync: PytestFixture.decorator ) # decorate test methods to run with pytest fixture - # list of attributes that can be added to the CrossSync class at runtime - _runtime_replacements: set[Any] = set() - - @classmethod - def add_mapping(cls, name, value): - """ - Add a new attribute to the CrossSync class, for replacing library-level symbols - - Raises: - - AttributeError if the attribute already exists with a different value - """ - if not hasattr(cls, name): - cls._runtime_replacements.add(name) - elif value != getattr(cls, name): - raise AttributeError(f"Conflicting assignments for CrossSync.{name}") - setattr(cls, name, value) - @classmethod def Mock(cls, *args, **kwargs): """ @@ -256,7 +239,7 @@ def rm_aio(statement: Any) -> Any: """ return statement - class _Sync_Impl: + class _Sync_Impl(metaclass=MappingMeta): """ Provide sync versions of the async functions and types in CrossSync """ @@ -280,30 +263,6 @@ class _Sync_Impl: Iterator: TypeAlias = typing.Iterator Generator: TypeAlias = typing.Generator - _runtime_replacements: set[Any] = set() - - @classmethod - def add_mapping_decorator(cls, name): - def decorator(wrapped_cls): - cls.add_mapping(name, wrapped_cls) - return wrapped_cls - - return decorator - - @classmethod - def add_mapping(cls, name, value): - """ - Add a new attribute to the CrossSync class, for replacing library-level symbols - - Raises: - - AttributeError if the attribute already exists 
with a different value - """ - if not hasattr(cls, name): - cls._runtime_replacements.add(name) - elif value != getattr(cls, name): - raise AttributeError(f"Conflicting assignments for CrossSync.{name}") - setattr(cls, name, value) - @classmethod def Mock(cls, *args, **kwargs): # try/except added for compatibility with python < 3.8 From 6ce1b5cf33541140747da69618cc518d9f6bdc81 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 24 Jul 2024 15:39:50 -0600 Subject: [PATCH 217/360] fixed mypy issues --- .../bigtable/data/_async/_mutate_rows.py | 16 ++++++--------- .../cloud/bigtable/data/_async/_read_rows.py | 10 ++++++++-- google/cloud/bigtable/data/_async/client.py | 20 +++++++------------ .../bigtable/data/_async/mutations_batcher.py | 7 +++---- .../data/_sync/cross_sync/cross_sync.py | 2 +- 5 files changed, 25 insertions(+), 30 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 7afc65816..34bc72b4e 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -34,15 +34,11 @@ from google.cloud.bigtable.data.mutations import RowMutationEntry if CrossSync.is_async: - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) - - CrossSync.add_mapping("GapicClient", BigtableAsyncClient) + from google.cloud.bigtable_v2.services.bigtable.async_client import BigtableAsyncClient as GapicClientType + from google.cloud.bigtable.data._async.client import TableAsync as TableType else: - from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient - - CrossSync.add_mapping("GapicClient", BigtableClient) + from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient as GapicClientType # type: ignore + from google.cloud.bigtable.data._sync.client import Table as TableType # type: ignore @CrossSync.export_sync( @@ -71,8 +67,8 @@ class _MutateRowsOperationAsync: @CrossSync.convert def __init__( self, - gapic_client: CrossSync.GapicClient, - table: CrossSync.Table, + gapic_client: GapicClientType, + table: TableType, mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index cb139a829..04d9de9e7 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -15,7 +15,7 @@ from __future__ import annotations -from typing import Sequence +from typing import Sequence, TYPE_CHECKING from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB @@ -36,6 +36,12 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if TYPE_CHECKING: + if CrossSync.is_async: + from google.cloud.bigtable.data._async.client import TableAsync as TableType + else: + from google.cloud.bigtable.data._sync.client import Table as TableType # type: ignore + @CrossSync.export_sync( path="google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", @@ -75,7 +81,7 @@ class _ReadRowsOperationAsync: def __init__( self, query: ReadRowsQuery, - table: CrossSync.Table, + table: TableType, operation_timeout: float, attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 
bc389f2d4..1676923c9 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -76,9 +76,7 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync if CrossSync.is_async: - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, - ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport as PooledTransportType from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( PooledChannel as AsyncPooledChannel, ) @@ -93,30 +91,26 @@ # define file-specific cross-sync replacements CrossSync.add_mapping("GapicClient", BigtableAsyncClient) - CrossSync.add_mapping("PooledTransport", PooledBigtableGrpcAsyncIOTransport) CrossSync.add_mapping("PooledChannel", AsyncPooledChannel) CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) CrossSync.add_mapping("MutationsBatcher", MutationsBatcherAsync) else: - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledBigtableGrpcTransport, - ) + from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport as PooledTransportType # type: ignore from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( PooledChannel, ) from google.cloud.bigtable_v2.services.bigtable.client import ( BigtableClient, ) - from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation - from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation - from google.cloud.bigtable.data._sync.mutations_batcher import ( + from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation # type: ignore + from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation # type: ignore + from google.cloud.bigtable.data._sync.mutations_batcher import ( # type: ignore MutationsBatcher, ) # define file-specific cross-sync replacements CrossSync.add_mapping("GapicClient", BigtableClient) - CrossSync.add_mapping("PooledTransport", PooledBigtableGrpcTransport) CrossSync.add_mapping("PooledChannel", PooledChannel) CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperation) CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperation) @@ -168,7 +162,7 @@ def __init__( """ # set up transport in registry transport_str = f"bt-{self._client_version()}-{pool_size}" - transport = CrossSync.PooledTransport.with_fixed_size(pool_size) + transport = PooledTransportType.with_fixed_size(pool_size) BigtableClientMeta._transport_registry[transport_str] = transport # set up client info headers for veneer library client_info = DEFAULT_CLIENT_INFO @@ -200,7 +194,7 @@ def __init__( client_info=client_info, ) self._is_closed = CrossSync.Event() - self.transport = cast(CrossSync.PooledTransport, self._gapic_client.transport) + self.transport = cast(PooledTransportType, self._gapic_client.transport) # keep track of active instances to for warmup on channel refresh self._active_instances: Set[_WarmedInstanceKey] = set() # keep track of table objects associated with each instance diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 253bfea4a..d075768df 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ 
-38,9 +38,9 @@ from google.cloud.bigtable.data.mutations import RowMutationEntry if CrossSync.is_async: - from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._async.client import TableAsync as TableType else: - from google.cloud.bigtable.data._sync.client import Table + from google.cloud.bigtable.data._sync.client import Table as TableType # type: ignore @CrossSync.export_sync( @@ -213,10 +213,9 @@ class MutationsBatcherAsync: Defaults to the Table's default_mutate_rows_retryable_errors. """ - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) def __init__( self, - table: TableAsync, + table: TableType, *, flush_interval: float | None = 5, flush_limit_mutation_count: int | None = 1000, diff --git a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py index 04ac79c73..6c21da5e0 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py @@ -230,7 +230,7 @@ def verify_async_event_loop() -> None: asyncio.get_running_loop() @staticmethod - def rm_aio(statement: Any) -> Any: + def rm_aio(statement: T) -> T: """ Used to annotate regions of the code containing async keywords to strip From 6ac4c718df6f2d11d21416dd7c7b53c345467a8d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 24 Jul 2024 16:17:50 -0600 Subject: [PATCH 218/360] ignore unneeded mypy check --- noxfile.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/noxfile.py b/noxfile.py index 5fb94526d..208f28e44 100644 --- a/noxfile.py +++ b/noxfile.py @@ -143,10 +143,9 @@ def mypy(session): "--check-untyped-defs", "--warn-unreachable", "--disallow-any-generics", - "--exclude", - "tests/system/v2_client", - "--exclude", - "tests/unit/v2_client", + "--exclude", "tests/system/v2_client", + "--exclude", "tests/unit/v2_client", + "--disable-error-code", "func-returns-value" # needed for CrossSync.rm_aio ) From d76cb5ca59e91a76ff5da381b39ec569af9578ad Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 24 Jul 2024 19:47:17 -0700 Subject: [PATCH 219/360] add mappings in __init__.py --- google/cloud/bigtable/data/__init__.py | 17 +++++++++++++ .../bigtable/data/_async/_mutate_rows.py | 1 - .../cloud/bigtable/data/_async/_read_rows.py | 1 - google/cloud/bigtable/data/_async/client.py | 25 ------------------- 4 files changed, 17 insertions(+), 27 deletions(-) diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index 66fe3479b..599d2b213 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -44,6 +44,23 @@ from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery +# setup custom CrossSync mappings for library +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( + PooledChannel as AsyncPooledChannel, +) +from google.cloud.bigtable_v2.services.bigtable.async_client import ( + BigtableAsyncClient, +) +from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync +from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync +from google.cloud.bigtable.data._sync.cross_sync import CrossSync +CrossSync.add_mapping("GapicClient", BigtableAsyncClient) 
+CrossSync.add_mapping("PooledChannel", AsyncPooledChannel) +CrossSync.add_mapping("PooledTransport", PooledBigtableGrpcAsyncIOTransport) +CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) +CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) + __version__: str = package_version.__version__ diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 34bc72b4e..a6540958a 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -43,7 +43,6 @@ @CrossSync.export_sync( path="google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", - add_mapping_for_name="_MutateRowsOperation", ) class _MutateRowsOperationAsync: """ diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 04d9de9e7..5611e3fc1 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -45,7 +45,6 @@ @CrossSync.export_sync( path="google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", - add_mapping_for_name="_ReadRowsOperation", ) class _ReadRowsOperationAsync: """ diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 1676923c9..2c495d7d9 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -77,45 +77,20 @@ if CrossSync.is_async: from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import PooledBigtableGrpcAsyncIOTransport as PooledTransportType - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel as AsyncPooledChannel, - ) - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync from google.cloud.bigtable.data._async.mutations_batcher import ( MutationsBatcherAsync, ) - # define file-specific cross-sync replacements - CrossSync.add_mapping("GapicClient", BigtableAsyncClient) - CrossSync.add_mapping("PooledChannel", AsyncPooledChannel) - CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) - CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) - CrossSync.add_mapping("MutationsBatcher", MutationsBatcherAsync) else: from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport as PooledTransportType # type: ignore - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledChannel, - ) - from google.cloud.bigtable_v2.services.bigtable.client import ( - BigtableClient, - ) from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation # type: ignore from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation # type: ignore from google.cloud.bigtable.data._sync.mutations_batcher import ( # type: ignore MutationsBatcher, ) - # define file-specific cross-sync replacements - CrossSync.add_mapping("GapicClient", BigtableClient) - CrossSync.add_mapping("PooledChannel", PooledChannel) - CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperation) - CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperation) - CrossSync.add_mapping("MutationsBatcher", MutationsBatcher) - if TYPE_CHECKING: from 
google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery From 449bb3935e162e56eb7380a18e48b93cf716e5ab Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 12 Aug 2024 18:04:41 -0700 Subject: [PATCH 220/360] changed mapping method --- google/cloud/bigtable/data/__init__.py | 18 +++++++- .../cloud/bigtable/data/_sync/_mutate_rows.py | 17 ++++---- .../cloud/bigtable/data/_sync/_read_rows.py | 11 +++-- google/cloud/bigtable/data/_sync/client.py | 43 ++----------------- .../bigtable/data/_sync/mutations_batcher.py | 6 +-- tests/unit/data/_sync/test_client.py | 4 +- 6 files changed, 42 insertions(+), 57 deletions(-) diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index 2d5be19a8..d3e30ee2f 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -59,14 +59,30 @@ ) from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledBigtableGrpcTransport, +) +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledChannel, +) +from google.cloud.bigtable_v2.services.bigtable.client import ( + BigtableClient, +) +from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation +from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation + from google.cloud.bigtable.data._sync.cross_sync import CrossSync CrossSync.add_mapping("GapicClient", BigtableAsyncClient) +CrossSync._Sync_Impl.add_mapping("GapicClient", BigtableClient) CrossSync.add_mapping("PooledChannel", AsyncPooledChannel) +CrossSync._Sync_Impl.add_mapping("PooledChannel", PooledChannel) CrossSync.add_mapping("PooledTransport", PooledBigtableGrpcAsyncIOTransport) +CrossSync._Sync_Impl.add_mapping("PooledTransport", PooledBigtableGrpcTransport) CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) +CrossSync._Sync_Impl.add_mapping("_ReadRowsOperation", _ReadRowsOperation) CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) - +CrossSync._Sync_Impl.add_mapping("_MutateRowsOperation", _MutateRowsOperation) __version__: str = package_version.__version__ diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index 8e01f95e8..ee3b81fdc 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -31,17 +31,16 @@ if CrossSync._Sync_Impl.is_async: from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, + BigtableAsyncClient as GapicClientType, ) - - CrossSync._Sync_Impl.add_mapping("GapicClient", BigtableAsyncClient) + from google.cloud.bigtable.data._async.client import TableAsync as TableType else: - from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient - - CrossSync._Sync_Impl.add_mapping("GapicClient", BigtableClient) + from google.cloud.bigtable_v2.services.bigtable.client import ( + BigtableClient as GapicClientType, + ) + from google.cloud.bigtable.data._sync.client import Table as TableType -@CrossSync._Sync_Impl.add_mapping_decorator("_MutateRowsOperation") class _MutateRowsOperation: """ MutateRowsOperation manages the logic of sending a set of row mutations, @@ -63,8 +62,8 @@ class _MutateRowsOperation: def __init__( self, - 
gapic_client: "CrossSync.GapicClient", - table: "CrossSync.Table", + gapic_client: GapicClientType, + table: TableType, mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index 08e0dfbb2..552389302 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -14,7 +14,7 @@ # # This file is automatically generated by CrossSync. Do not edit manually. from __future__ import annotations -from typing import Sequence +from typing import Sequence, TYPE_CHECKING from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB from google.cloud.bigtable_v2.types import RowSet as RowSetPB @@ -31,8 +31,13 @@ from google.api_core.retry import exponential_sleep_generator from google.cloud.bigtable.data._sync.cross_sync import CrossSync +if TYPE_CHECKING: + if CrossSync._Sync_Impl.is_async: + from google.cloud.bigtable.data._async.client import TableAsync as TableType + else: + from google.cloud.bigtable.data._sync.client import Table as TableType + -@CrossSync._Sync_Impl.add_mapping_decorator("_ReadRowsOperation") class _ReadRowsOperation: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream @@ -67,7 +72,7 @@ class _ReadRowsOperation: def __init__( self, query: ReadRowsQuery, - table: "CrossSync.Table", + table: TableType, operation_timeout: float, attempt_timeout: float, retryable_exceptions: Sequence[type[Exception]] = (), diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 4bbd5498b..95d3cc05e 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -61,46 +61,13 @@ if CrossSync._Sync_Impl.is_async: from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport, + PooledBigtableGrpcAsyncIOTransport as PooledTransportType, ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledChannel as AsyncPooledChannel, - ) - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient, - ) - from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync - from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - ) - - CrossSync._Sync_Impl.add_mapping("GapicClient", BigtableAsyncClient) - CrossSync._Sync_Impl.add_mapping( - "PooledTransport", PooledBigtableGrpcAsyncIOTransport - ) - CrossSync._Sync_Impl.add_mapping("PooledChannel", AsyncPooledChannel) - CrossSync._Sync_Impl.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) - CrossSync._Sync_Impl.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) - CrossSync._Sync_Impl.add_mapping("MutationsBatcher", MutationsBatcherAsync) else: from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledBigtableGrpcTransport, - ) - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledChannel, + PooledBigtableGrpcTransport as PooledTransportType, ) - from google.cloud.bigtable_v2.services.bigtable.client import BigtableClient - from google.cloud.bigtable.data._sync._read_rows 
import _ReadRowsOperation - from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher - - CrossSync._Sync_Impl.add_mapping("GapicClient", BigtableClient) - CrossSync._Sync_Impl.add_mapping("PooledTransport", PooledBigtableGrpcTransport) - CrossSync._Sync_Impl.add_mapping("PooledChannel", PooledChannel) - CrossSync._Sync_Impl.add_mapping("_ReadRowsOperation", _ReadRowsOperation) - CrossSync._Sync_Impl.add_mapping("_MutateRowsOperation", _MutateRowsOperation) - CrossSync._Sync_Impl.add_mapping("MutationsBatcher", MutationsBatcher) if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery @@ -140,7 +107,7 @@ def __init__( RuntimeError: if called outside of an async context (no running event loop) ValueError: if pool_size is less than 1""" transport_str = f"bt-{self._client_version()}-{pool_size}" - transport = CrossSync._Sync_Impl.PooledTransport.with_fixed_size(pool_size) + transport = PooledTransportType.with_fixed_size(pool_size) BigtableClientMeta._transport_registry[transport_str] = transport client_info = DEFAULT_CLIENT_INFO client_info.client_library_version = self._client_version() @@ -168,9 +135,7 @@ def __init__( client_info=client_info, ) self._is_closed = CrossSync._Sync_Impl.Event() - self.transport = cast( - CrossSync._Sync_Impl.PooledTransport, self._gapic_client.transport - ) + self.transport = cast(PooledTransportType, self._gapic_client.transport) self._active_instances: Set[_WarmedInstanceKey] = set() self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 232ca6e12..e9264c316 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -36,9 +36,9 @@ from google.cloud.bigtable.data.mutations import RowMutationEntry if CrossSync._Sync_Impl.is_async: - pass + from google.cloud.bigtable.data._async.client import TableAsync as TableType else: - from google.cloud.bigtable.data._sync.client import Table + from google.cloud.bigtable.data._sync.client import Table as TableType @CrossSync._Sync_Impl.add_mapping_decorator("_FlowControl") @@ -185,7 +185,7 @@ class MutationsBatcher: def __init__( self, - table: Table, + table: TableType, *, flush_interval: float | None = 5, flush_limit_mutation_count: int | None = 1000, diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index b67f77298..807a4914f 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -1162,7 +1162,7 @@ def test_customizable_retryable_errors( ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), ("row_exists", (b"row_key",), "read_rows"), ("sample_row_keys", (), "sample_row_keys"), - ("mutate_row", (b"row_key", [mock.Mock()]), "mutate_row"), + ("mutate_row", (b"row_key", [mutations.DeleteAllFromRow()]), "mutate_row"), ( "bulk_mutate_rows", ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), @@ -1171,7 +1171,7 @@ def test_customizable_retryable_errors( ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), ( "read_modify_write_row", - (b"row_key", mock.Mock()), + (b"row_key", IncrementRule("f", "q")), "read_modify_write_row", ), ], From e5d6a29ec57765091f0834924ab895ca59547eab 
Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 13 Aug 2024 16:46:15 -0700 Subject: [PATCH 221/360] added annotations for execute_query --- google/cloud/bigtable/data/__init__.py | 5 + google/cloud/bigtable/data/_async/client.py | 15 +- .../_async/execute_query_iterator.py | 64 +++-- .../data/execute_query/_async/_testing.py | 36 --- .../_async/test_query_iterator.py | 252 ++++++++++-------- 5 files changed, 192 insertions(+), 180 deletions(-) delete mode 100644 tests/unit/data/execute_query/_async/_testing.py diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index cc09418b5..e176ec765 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -57,6 +57,10 @@ ) from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync +from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( + ExecuteQueryIteratorAsync, +) + from google.cloud.bigtable.data._sync.cross_sync import CrossSync CrossSync.add_mapping("GapicClient", BigtableAsyncClient) @@ -64,6 +68,7 @@ CrossSync.add_mapping("PooledTransport", PooledBigtableGrpcAsyncIOTransport) CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) +CrossSync.add_mapping("ExecuteQueryIterator", ExecuteQueryIteratorAsync) __version__: str = package_version.__version__ diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index f82393fa5..340a948dd 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -34,9 +34,6 @@ from functools import partial from grpc import Channel -from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( - ExecuteQueryIteratorAsync, -) from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType from google.cloud.bigtable.data.execute_query.metadata import SqlType from google.cloud.bigtable.data.execute_query._parameters_formatting import ( @@ -91,12 +88,19 @@ from google.cloud.bigtable.data._async.mutations_batcher import ( MutationsBatcherAsync, ) + from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( + ExecuteQueryIteratorAsync, + ) else: from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport as PooledTransportType # type: ignore from google.cloud.bigtable.data._sync.mutations_batcher import ( # noqa: F401 MutationsBatcher, ) + from google.cloud.bigtable.data.execute_query._sync.execute_query_iterator import ( # noqa: F401 + ExecuteQueryIterator, + ) + if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples @@ -371,7 +375,7 @@ async def _manage_channel( next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = next_refresh - (time.monotonic() - start_timestamp) - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) + @CrossSync.convert(replace_symbols={"TableAsync": "Table", "ExecuteQueryIteratorAsync": "ExecuteQueryIterator"}) async def _register_instance( self, instance_id: str, owner: TableAsync | ExecuteQueryIteratorAsync ) -> None: @@ -406,7 +410,7 @@ async def _register_instance( # refresh tasks aren't active. 
start them as background tasks self._start_background_channel_refresh() - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) + @CrossSync.convert(replace_symbols={"TableAsync": "Table", "ExecuteQueryIteratorAsync": "ExecuteQueryIterator"}) async def _remove_instance_registration( self, instance_id: str, owner: TableAsync | ExecuteQueryIteratorAsync ) -> bool: @@ -479,6 +483,7 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs """ return TableAsync(self, instance_id, table_id, *args, **kwargs) + @CrossSync.convert(replace_symbols={"ExecuteQueryIteratorAsync": "ExecuteQueryIterator"}) async def execute_query( self, query: str, diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index 32081939b..9d2832d10 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -44,10 +44,16 @@ ExecuteQueryRequest as ExecuteQueryRequestPB, ) +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + if TYPE_CHECKING: - from google.cloud.bigtable.data import BigtableDataClientAsync + if CrossSync.is_async: + from google.cloud.bigtable.data import BigtableDataClientAsync as DataClientType + else: + from google.cloud.bigtable.data import BigtableDataClient as DataClientType +@CrossSync.export_sync(path="google.cloud.bigtable.data.execute_query._sync.execute_query_iterator.ExecuteQueryIterator") class ExecuteQueryIteratorAsync: """ ExecuteQueryIteratorAsync handles collecting streaming responses from the @@ -77,7 +83,7 @@ class ExecuteQueryIteratorAsync: def __init__( self, - client: BigtableDataClientAsync, + client: DataClientType, instance_id: str, app_profile_id: Optional[str], request_body: Dict[str, Any], @@ -99,7 +105,7 @@ def __init__( self._attempt_timeout_gen = _attempt_timeout_generator( attempt_timeout, operation_timeout ) - self._async_stream = retries.retry_target_stream_async( + self._stream = CrossSync.retry_target_stream( self._make_request_with_resume_token, retries.if_exception_type(*retryable_excs), retries.exponential_sleep_generator(0.01, 60, multiplier=2), @@ -109,8 +115,11 @@ def __init__( self._req_metadata = req_metadata try: - self._register_instance_task = asyncio.create_task( - self._client._register_instance(instance_id, self) + self._register_instance_task = CrossSync.create_task( + self._client._register_instance, + instance_id, + self, + sync_executor=self._client._executor, ) except RuntimeError as e: raise RuntimeError( @@ -132,6 +141,7 @@ def table_name(self) -> Optional[str]: """Returns the table_name of the iterator.""" return self._table_name + @CrossSync.convert async def _make_request_with_resume_token(self): """ perfoms the rpc call using the correct resume token. 
@@ -143,30 +153,34 @@ async def _make_request_with_resume_token(self): "resume_token": resume_token, } ) - return await self._client._gapic_client.execute_query( - request, - timeout=next(self._attempt_timeout_gen), - metadata=self._req_metadata, - retry=None, + return CrossSync.rm_aio( + await self._client._gapic_client.execute_query( + request, + timeout=next(self._attempt_timeout_gen), + metadata=self._req_metadata, + retry=None, + ) ) - async def _await_metadata(self) -> None: + @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) + async def _fetch_metadata(self) -> None: """ If called before the first response was recieved, the first response - is awaited as part of this call. + is retrieved as part of this call. """ if self._byte_cursor.metadata is None: - metadata_msg = await self._async_stream.__anext__() + metadata_msg = CrossSync.rm_aio(await self._stream.__anext__()) self._byte_cursor.consume_metadata(metadata_msg) - async def _next_impl(self) -> AsyncIterator[QueryResultRow]: + @CrossSync.convert + async def _next_impl(self) -> CrossSync.Iterator[QueryResultRow]: """ Generator wrapping the response stream which parses the stream results and returns full `QueryResultRow`s. """ - await self._await_metadata() + CrossSync.rm_aio(await self._fetch_metadata()) - async for response in self._async_stream: + async for response in CrossSync.rm_aio(self._stream): try: bytes_to_parse = self._byte_cursor.consume(response) if bytes_to_parse is None: @@ -183,16 +197,19 @@ async def _next_impl(self) -> AsyncIterator[QueryResultRow]: for result in results: yield result - await self.close() + CrossSync.rm_aio(await self.close()) + @CrossSync.convert(sync_name="__next__", replace_symbols={"__anext__": "__next__"}) async def __anext__(self) -> QueryResultRow: if self._is_closed: - raise StopAsyncIteration - return await self._result_generator.__anext__() + raise CrossSync.StopIteration + return CrossSync.rm_aio(await self._result_generator.__anext__()) + @CrossSync.convert(sync_name="__iter__") def __aiter__(self): return self + @CrossSync.convert async def metadata(self) -> Optional[Metadata]: """ Returns query metadata from the server or None if the iterator was @@ -203,11 +220,12 @@ async def metadata(self) -> Optional[Metadata]: # Metadata should be present in the first response in a stream. if self._byte_cursor.metadata is None: try: - await self._await_metadata() - except StopIteration: + CrossSync.rm_aio(await self._fetch_metadata()) + except CrossSync.StopIteration: return None return self._byte_cursor.metadata + @CrossSync.convert async def close(self) -> None: """ Cancel all background tasks. Should be called all rows were processed. @@ -217,4 +235,6 @@ async def close(self) -> None: self._is_closed = True if self._register_instance_task is not None: self._register_instance_task.cancel() - await self._client._remove_instance_registration(self._instance_id, self) + CrossSync.rm_aio( + await self._client._remove_instance_registration(self._instance_id, self) + ) diff --git a/tests/unit/data/execute_query/_async/_testing.py b/tests/unit/data/execute_query/_async/_testing.py deleted file mode 100644 index 5a7acbdd9..000000000 --- a/tests/unit/data/execute_query/_async/_testing.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# flake8: noqa -from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes - - -try: - # async mock for python3.7-10 - from unittest.mock import Mock - from asyncio import coroutine - - def async_mock(return_value=None): - coro = Mock(name="CoroutineResult") - corofunc = Mock(name="CoroutineFunction", side_effect=coroutine(coro)) - corofunc.coro = coro - corofunc.coro.return_value = return_value - return corofunc - -except ImportError: - # async mock for python3.11 or later - from unittest.mock import AsyncMock - - def async_mock(return_value=None): - return AsyncMock(return_value=return_value) diff --git a/tests/unit/data/execute_query/_async/test_query_iterator.py b/tests/unit/data/execute_query/_async/test_query_iterator.py index 5c577ed74..08a7e0711 100644 --- a/tests/unit/data/execute_query/_async/test_query_iterator.py +++ b/tests/unit/data/execute_query/_async/test_query_iterator.py @@ -13,144 +13,162 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -from unittest.mock import Mock -from mock import patch import pytest +import concurrent.futures from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( ExecuteQueryIteratorAsync, ) from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse -from ._testing import TYPE_INT, proto_rows_bytes, split_bytes_into_chunks, async_mock +from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes +from google.cloud.bigtable.data._sync.cross_sync import CrossSync -class MockIteratorAsync: +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock +except ImportError: # pragma: NO COVER + import mock # type: ignore + + +@CrossSync.export_sync(path="tests.unit.data.execute_query._sync.test_query_iterator.MockIterator") +class MockIterator: def __init__(self, values, delay=None): self._values = values self.idx = 0 self._delay = delay + @CrossSync.convert(sync_name="__iter__") def __aiter__(self): return self + @CrossSync.convert(sync_name="__next__") async def __anext__(self): if self.idx >= len(self._values): - raise StopAsyncIteration + raise CrossSync.StopIteration if self._delay is not None: - await asyncio.sleep(self._delay) + CrossSync.rm_aio(await CrossSync.sleep(self._delay)) value = self._values[self.idx] self.idx += 1 return value -@pytest.fixture -def proto_byte_stream(): - proto_rows = [ - proto_rows_bytes({"int_value": 1}, {"int_value": 2}), - proto_rows_bytes({"int_value": 3}, {"int_value": 4}), - proto_rows_bytes({"int_value": 5}, {"int_value": 6}), - ] - - messages = [ - *split_bytes_into_chunks(proto_rows[0], num_chunks=2), - *split_bytes_into_chunks(proto_rows[1], num_chunks=3), - proto_rows[2], - ] - - stream = [ - ExecuteQueryResponse( - metadata={ - "proto_schema": { - "columns": [ - {"name": "test1", "type_": TYPE_INT}, - {"name": "test2", "type_": TYPE_INT}, - ] +@CrossSync.export_sync(path="tests.unit.data.execute_query._sync.test_query_iterator.TestQueryIterator") +class TestQueryIteratorAsync: + + @staticmethod + def _target_class(): + 
return CrossSync.ExecuteQueryIterator + + def _make_one(self, *args, **kwargs): + return self._target_class()(*args, **kwargs) + + @pytest.fixture + def proto_byte_stream(self): + proto_rows = [ + proto_rows_bytes({"int_value": 1}, {"int_value": 2}), + proto_rows_bytes({"int_value": 3}, {"int_value": 4}), + proto_rows_bytes({"int_value": 5}, {"int_value": 6}), + ] + + messages = [ + *split_bytes_into_chunks(proto_rows[0], num_chunks=2), + *split_bytes_into_chunks(proto_rows[1], num_chunks=3), + proto_rows[2], + ] + + stream = [ + ExecuteQueryResponse( + metadata={ + "proto_schema": { + "columns": [ + {"name": "test1", "type_": TYPE_INT}, + {"name": "test2", "type_": TYPE_INT}, + ] + } + } + ), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[0]}}), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[1]}, + "resume_token": b"token1", + } + ), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[2]}}), + ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[3]}}), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[4]}, + "resume_token": b"token2", + } + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[5]}, + "resume_token": b"token3", } - } - ), - ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[0]}}), - ExecuteQueryResponse( - results={ - "proto_rows_batch": {"batch_data": messages[1]}, - "resume_token": b"token1", - } - ), - ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[2]}}), - ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[3]}}), - ExecuteQueryResponse( - results={ - "proto_rows_batch": {"batch_data": messages[4]}, - "resume_token": b"token2", - } - ), - ExecuteQueryResponse( - results={ - "proto_rows_batch": {"batch_data": messages[5]}, - "resume_token": b"token3", - } - ), - ] - return stream - - -@pytest.mark.asyncio -async def test_iterator(proto_byte_stream): - client_mock = Mock() - - client_mock._register_instance = async_mock() - client_mock._remove_instance_registration = async_mock() - mock_async_iterator = MockIteratorAsync(proto_byte_stream) - iterator = None - - with patch( - "google.api_core.retry.retry_target_stream_async", - return_value=mock_async_iterator, - ): - iterator = ExecuteQueryIteratorAsync( - client=client_mock, - instance_id="test-instance", - app_profile_id="test_profile", - request_body={}, - attempt_timeout=10, - operation_timeout=10, - req_metadata=(), - retryable_excs=[], - ) - result = [] - async for value in iterator: - result.append(tuple(value)) - assert result == [(1, 2), (3, 4), (5, 6)] - - assert iterator.is_closed - client_mock._register_instance.assert_called_once() - client_mock._remove_instance_registration.assert_called_once() - - assert mock_async_iterator.idx == len(proto_byte_stream) - - -@pytest.mark.asyncio -async def test_iterator_awaits_metadata(proto_byte_stream): - client_mock = Mock() - - client_mock._register_instance = async_mock() - client_mock._remove_instance_registration = async_mock() - mock_async_iterator = MockIteratorAsync(proto_byte_stream) - iterator = None - with patch( - "google.api_core.retry.retry_target_stream_async", - return_value=mock_async_iterator, - ): - iterator = ExecuteQueryIteratorAsync( - client=client_mock, - instance_id="test-instance", - app_profile_id="test_profile", - request_body={}, - attempt_timeout=10, - operation_timeout=10, - req_metadata=(), - 
retryable_excs=[], - ) - - await iterator.metadata() - - assert mock_async_iterator.idx == 1 + ), + ] + return stream + + + @CrossSync.pytest + async def test_iterator(self, proto_byte_stream): + client_mock = mock.Mock() + + client_mock._register_instance = CrossSync.Mock() + client_mock._remove_instance_registration = CrossSync.Mock() + client_mock._executor = concurrent.futures.ThreadPoolExecutor() + mock_async_iterator = MockIterator(proto_byte_stream) + iterator = None + + with mock.patch.object( + CrossSync, "retry_target_stream", return_value=mock_async_iterator, + ): + iterator = self._make_one( + client=client_mock, + instance_id="test-instance", + app_profile_id="test_profile", + request_body={}, + attempt_timeout=10, + operation_timeout=10, + req_metadata=(), + retryable_excs=[], + ) + result = [] + async for value in iterator: + result.append(tuple(value)) + assert result == [(1, 2), (3, 4), (5, 6)] + + assert iterator.is_closed + client_mock._register_instance.assert_called_once() + client_mock._remove_instance_registration.assert_called_once() + + assert mock_async_iterator.idx == len(proto_byte_stream) + + + @CrossSync.pytest + async def test_iterator_awaits_metadata(self, proto_byte_stream): + client_mock = mock.Mock() + + client_mock._register_instance = CrossSync.Mock() + client_mock._remove_instance_registration = CrossSync.Mock() + mock_async_iterator = MockIterator(proto_byte_stream) + iterator = None + with mock.patch.object( + CrossSync, "retry_target_stream", return_value=mock_async_iterator, + ): + iterator = self._make_one( + client=client_mock, + instance_id="test-instance", + app_profile_id="test_profile", + request_body={}, + attempt_timeout=10, + operation_timeout=10, + req_metadata=(), + retryable_excs=[], + ) + + await iterator.metadata() + + assert mock_async_iterator.idx == 1 From ebf126a2fb56e382a27c5fd8caefee79d4cd731e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 13 Aug 2024 16:46:38 -0700 Subject: [PATCH 222/360] replace all instances in docstrings --- .cross_sync/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 76439b80e..a5c4eeb6a 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -75,7 +75,7 @@ def visit_FunctionDef(self, node): node.body[0].value, ast.Str ): for key_word, replacement in self.replacements.items(): - docstring = docstring.replace(f" {key_word} ", f" {replacement} ") + docstring = docstring.replace(key_word, replacement) node.body[0].value.s = docstring return self.generic_visit(node) From 2c926df37e6818a093311d71690e5c4950b12e01 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 13 Aug 2024 16:49:28 -0700 Subject: [PATCH 223/360] regenerated files --- .../cloud/bigtable/data/_sync/_mutate_rows.py | 4 +- .../cloud/bigtable/data/_sync/_read_rows.py | 4 +- google/cloud/bigtable/data/_sync/client.py | 126 ++++++++++- .../_sync/execute_query_iterator.py | 196 ++++++++++++++++++ tests/unit/data/_sync/test__read_rows.py | 2 +- .../unit/data/execute_query/_sync/__init__.py | 13 ++ .../_sync/test_query_iterator.py | 161 ++++++++++++++ 7 files changed, 494 insertions(+), 12 deletions(-) create mode 100644 google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py create mode 100644 tests/unit/data/execute_query/_sync/__init__.py create mode 100644 tests/unit/data/execute_query/_sync/test_query_iterator.py diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py 
b/google/cloud/bigtable/data/_sync/_mutate_rows.py index ee3b81fdc..f36557743 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -74,7 +74,9 @@ def __init__( raise ValueError( f"mutate_rows requests can contain at most {_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across all entries. Found {total_mutations}." ) - metadata = _make_metadata(table.table_name, table.app_profile_id) + metadata = _make_metadata( + table.table_name, table.app_profile_id, instance_name=None + ) self._gapic_fn = functools.partial( gapic_client.mutate_rows, table_name=table.table_name, diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index 552389302..b5fa35479 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -91,7 +91,9 @@ def __init__( self.request = query._to_pb(table) self.table = table self._predicate = retries.if_exception_type(*retryable_exceptions) - self._metadata = _make_metadata(table.table_name, table.app_profile_id) + self._metadata = _make_metadata( + table.table_name, table.app_profile_id, instance_name=None + ) self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 95d3cc05e..8274e922d 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -22,6 +22,11 @@ import concurrent.futures from functools import partial from grpc import Channel +from google.cloud.bigtable.data.execute_query.values import ExecuteQueryValueType +from google.cloud.bigtable.data.execute_query.metadata import SqlType +from google.cloud.bigtable.data.execute_query._parameters_formatting import ( + _format_execute_query_params, +) from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta from google.cloud.bigtable_v2.services.bigtable.transports.base import ( DEFAULT_CLIENT_INFO, @@ -47,6 +52,7 @@ from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory from google.cloud.bigtable.data._helpers import _validate_timeouts +from google.cloud.bigtable.data._helpers import _get_error_type from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import _attempt_timeout_generator @@ -68,6 +74,9 @@ PooledBigtableGrpcTransport as PooledTransportType, ) from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher + from google.cloud.bigtable.data.execute_query._sync.execute_query_iterator import ( + ExecuteQueryIterator, + ) if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery @@ -293,7 +302,9 @@ def _manage_channel( next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = next_refresh - (time.monotonic() - start_timestamp) - def _register_instance(self, instance_id: str, owner: Table) -> None: + def _register_instance( + self, instance_id: str, owner: Table | ExecuteQueryIterator + ) -> None: """Registers an instance with the client, and warms the channel pool for the instance The client will periodically refresh grpc channel pool used to make @@ -318,7 +329,9 @@ def 
_register_instance(self, instance_id: str, owner: Table) -> None: else: self._start_background_channel_refresh() - def _remove_instance_registration(self, instance_id: str, owner: Table) -> bool: + def _remove_instance_registration( + self, instance_id: str, owner: Table | ExecuteQueryIterator + ) -> bool: """Removes an instance from the client's registered instances, to prevent warming new channels for the instance @@ -378,12 +391,99 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> Table: encountered during all other operations. Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) Returns: - TableAsync: a table instance for making data API requests + Table: a table instance for making data API requests Raises: RuntimeError: if called outside of an async context (no running event loop) """ return Table(self, instance_id, table_id, *args, **kwargs) + def execute_query( + self, + query: str, + instance_id: str, + *, + parameters: dict[str, ExecuteQueryValueType] | None = None, + parameter_types: dict[str, SqlType.Type] | None = None, + app_profile_id: str | None = None, + operation_timeout: float = 600, + attempt_timeout: float | None = 20, + retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + Aborted, + ), + ) -> "ExecuteQueryIterator": + """Executes an SQL query on an instance. + Returns an iterator to asynchronously stream back columns from selected rows. + + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. + + Args: + query: Query to be run on Bigtable instance. The query can use ``@param`` + placeholders to use parameter interpolation on the server. Values for all + parameters should be provided in ``parameters``. Types of parameters are + inferred but should be provided in ``parameter_types`` if the inference is + not possible (i.e. when value can be None, an empty list or an empty dict). + instance_id: The Bigtable instance ID to perform the query on. + instance_id is combined with the client's project to fully + specify the instance. + parameters: Dictionary with values for all parameters used in the ``query``. + parameter_types: Dictionary with types of parameters used in the ``query``. + Required to contain entries only for parameters whose type cannot be + detected automatically (i.e. the value can be None, an empty list or + an empty dict). + app_profile_id: The app profile to associate with requests. + https://cloud.google.com/bigtable/docs/app-profiles + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to 600 seconds. + attempt_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the 20 seconds. + If None, defaults to operation_timeout. + retryable_errors: a list of errors that will be retried if encountered. 
+ Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + Returns: + ExecuteQueryIterator: an asynchronous iterator that yields rows returned by the query + Raises: + google.api_core.exceptions.DeadlineExceeded: raised after operation timeout + will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + from any retries that failed + google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error + """ + warnings.warn( + "ExecuteQuery is in preview and may change in the future.", + category=RuntimeWarning, + ) + retryable_excs = [_get_error_type(e) for e in retryable_errors] + pb_params = _format_execute_query_params(parameters, parameter_types) + instance_name = self._gapic_client.instance_path(self.project, instance_id) + request_body = { + "instance_name": instance_name, + "app_profile_id": app_profile_id, + "query": query, + "params": pb_params, + "proto_format": {}, + } + app_profile_id_for_metadata = app_profile_id or "" + req_metadata = _make_metadata( + table_name=None, + app_profile_id=app_profile_id_for_metadata, + instance_name=instance_name, + ) + return ExecuteQueryIterator( + self, + instance_id, + app_profile_id, + request_body, + attempt_timeout, + operation_timeout, + req_metadata, + retryable_excs, + ) + def __enter__(self): self._start_background_channel_refresh() return self @@ -543,7 +643,7 @@ def read_rows_stream( retryable_errors: a list of errors that will be retried if encountered. Defaults to the Table's default_read_rows_retryable_errors Returns: - AsyncIterable[Row]: an asynchronous iterator that yields rows returned by the query + Iterable[Row]: an asynchronous iterator that yields rows returned by the query Raises: google.api_core.exceptions.DeadlineExceeded: raised after operation timeout will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions @@ -838,7 +938,9 @@ def sample_row_keys( retryable_excs = _get_retryable_errors(retryable_errors, self) predicate = retries.if_exception_type(*retryable_excs) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - metadata = _make_metadata(self.table_name, self.app_profile_id) + metadata = _make_metadata( + self.table_name, self.app_profile_id, instance_name=None + ) def execute_rpc(): results = self.client._gapic_client.sample_row_keys( @@ -892,7 +994,7 @@ def mutations_batcher( batch_retryable_errors: a list of errors that will be retried if encountered. Defaults to the Table's default_mutate_rows_retryable_errors. 
Returns: - MutationsBatcherAsync: a MutationsBatcher context manager that can batch requests + MutationsBatcher: a MutationsBatcher context manager that can batch requests """ return CrossSync._Sync_Impl.MutationsBatcher( self, @@ -965,7 +1067,9 @@ def mutate_row( table_name=self.table_name, app_profile_id=self.app_profile_id, timeout=attempt_timeout, - metadata=_make_metadata(self.table_name, self.app_profile_id), + metadata=_make_metadata( + self.table_name, self.app_profile_id, instance_name=None + ), retry=None, ) return CrossSync._Sync_Impl.retry_target( @@ -1078,7 +1182,9 @@ def check_and_mutate_row( ): false_case_mutations = [false_case_mutations] false_case_list = [m._to_pb() for m in false_case_mutations or []] - metadata = _make_metadata(self.table_name, self.app_profile_id) + metadata = _make_metadata( + self.table_name, self.app_profile_id, instance_name=None + ) result = self.client._gapic_client.check_and_mutate_row( true_mutations=true_case_list, false_mutations=false_case_list, @@ -1127,7 +1233,9 @@ def read_modify_write_row( rules = [rules] if not rules: raise ValueError("rules must contain at least one item") - metadata = _make_metadata(self.table_name, self.app_profile_id) + metadata = _make_metadata( + self.table_name, self.app_profile_id, instance_name=None + ) result = self.client._gapic_client.read_modify_write_row( rules=[rule._to_pb() for rule in rules], row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, diff --git a/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py new file mode 100644 index 000000000..9afb59bef --- /dev/null +++ b/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py @@ -0,0 +1,196 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is automatically generated by CrossSync. Do not edit manually. 
+from __future__ import annotations +from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING +from google.api_core import retry as retries +from google.cloud.bigtable.data.execute_query._byte_cursor import _ByteCursor +from google.cloud.bigtable.data._helpers import ( + _attempt_timeout_generator, + _retry_exception_factory, +) +from google.cloud.bigtable.data.exceptions import InvalidExecuteQueryResponse +from google.cloud.bigtable.data.execute_query.values import QueryResultRow +from google.cloud.bigtable.data.execute_query.metadata import Metadata, ProtoMetadata +from google.cloud.bigtable.data.execute_query._reader import ( + _QueryResultRowReader, + _Reader, +) +from google.cloud.bigtable_v2.types.bigtable import ( + ExecuteQueryRequest as ExecuteQueryRequestPB, +) +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +if TYPE_CHECKING: + if CrossSync._Sync_Impl.is_async: + from google.cloud.bigtable.data import BigtableDataClientAsync as DataClientType + else: + from google.cloud.bigtable.data import BigtableDataClient as DataClientType + + +class ExecuteQueryIterator: + """ + ExecuteQueryIteratorAsync handles collecting streaming responses from the + ExecuteQuery RPC and parsing them to QueryResultRows. + + ExecuteQueryIteratorAsync implements Asynchronous Iterator interface and can + be used with "async for" syntax. It is also a context manager. + + It is **not thread-safe**. It should not be used by multiple asyncio Tasks. + + Args: + client: bigtable client + instance_id: id of the instance on which the query is executed + request_body: dict representing the body of the ExecuteQueryRequest + attempt_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to 600 seconds. + operation_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the 20 seconds. If None, defaults to operation_timeout. + req_metadata: metadata used while sending the gRPC request + retryable_excs: a list of errors that will be retried if encountered. + Raises: + RuntimeError: if the instance is not created within an async event loop context. 
+    """
+
+    def __init__(
+        self,
+        client: DataClientType,
+        instance_id: str,
+        app_profile_id: Optional[str],
+        request_body: Dict[str, Any],
+        attempt_timeout: float | None,
+        operation_timeout: float,
+        req_metadata: Sequence[Tuple[str, str]],
+        retryable_excs: List[type[Exception]],
+    ) -> None:
+        self._table_name = None
+        self._app_profile_id = app_profile_id
+        self._client = client
+        self._instance_id = instance_id
+        self._byte_cursor = _ByteCursor[ProtoMetadata]()
+        self._reader: _Reader[QueryResultRow] = _QueryResultRowReader(self._byte_cursor)
+        self._result_generator = self._next_impl()
+        self._register_instance_task = None
+        self._is_closed = False
+        self._request_body = request_body
+        self._attempt_timeout_gen = _attempt_timeout_generator(
+            attempt_timeout, operation_timeout
+        )
+        self._stream = CrossSync._Sync_Impl.retry_target_stream(
+            self._make_request_with_resume_token,
+            retries.if_exception_type(*retryable_excs),
+            retries.exponential_sleep_generator(0.01, 60, multiplier=2),
+            operation_timeout,
+            exception_factory=_retry_exception_factory,
+        )
+        self._req_metadata = req_metadata
+        try:
+            self._register_instance_task = CrossSync._Sync_Impl.create_task(
+                self._client._register_instance,
+                instance_id,
+                self,
+                sync_executor=self._client._executor,
+            )
+        except RuntimeError as e:
+            raise RuntimeError(
+                f"{self.__class__.__name__} must be created within an async event loop context."
+            ) from e
+
+    @property
+    def is_closed(self) -> bool:
+        """Returns True if the iterator is closed, False otherwise."""
+        return self._is_closed
+
+    @property
+    def app_profile_id(self) -> Optional[str]:
+        """Returns the app_profile_id of the iterator."""
+        return self._app_profile_id
+
+    @property
+    def table_name(self) -> Optional[str]:
+        """Returns the table_name of the iterator."""
+        return self._table_name
+
+    def _make_request_with_resume_token(self):
+        """performs the rpc call using the correct resume token."""
+        resume_token = self._byte_cursor.prepare_for_new_request()
+        request = ExecuteQueryRequestPB(
+            {**self._request_body, "resume_token": resume_token}
+        )
+        return self._client._gapic_client.execute_query(
+            request,
+            timeout=next(self._attempt_timeout_gen),
+            metadata=self._req_metadata,
+            retry=None,
+        )
+
+    def _fetch_metadata(self) -> None:
+        """If called before the first response was received, the first response
+        is retrieved as part of this call."""
+        if self._byte_cursor.metadata is None:
+            metadata_msg = self._stream.__next__()
+            self._byte_cursor.consume_metadata(metadata_msg)
+
+    def _next_impl(self) -> CrossSync._Sync_Impl.Iterator[QueryResultRow]:
+        """Generator wrapping the response stream which parses the stream results
+        and returns full `QueryResultRow`s."""
+        self._fetch_metadata()
+        for response in self._stream:
+            try:
+                bytes_to_parse = self._byte_cursor.consume(response)
+                if bytes_to_parse is None:
+                    continue
+                results = self._reader.consume(bytes_to_parse)
+                if results is None:
+                    continue
+            except ValueError as e:
+                raise InvalidExecuteQueryResponse(
+                    "Invalid ExecuteQuery response received"
+                ) from e
+            for result in results:
+                yield result
+        self.close()
+
+    def __next__(self) -> QueryResultRow:
+        if self._is_closed:
+            raise CrossSync._Sync_Impl.StopIteration
+        return self._result_generator.__next__()
+
+    def __iter__(self):
+        return self
+
+    def metadata(self) -> Optional[Metadata]:
+        """Returns query metadata from the server or None if the iterator was
+        explicitly closed."""
+        if self._is_closed:
+            return None
+        if self._byte_cursor.metadata is None:
+            try:
+                self._fetch_metadata()
+            except CrossSync._Sync_Impl.StopIteration:
+                return None
+        return self._byte_cursor.metadata
+
+    def close(self) -> None:
+        """Cancel all background tasks. Should be called after all rows were processed."""
+        if self._is_closed:
+            return
+        self._is_closed = True
+        if self._register_instance_task is not None:
+            self._register_instance_task.cancel()
+        self._client._remove_instance_registration(self._instance_id, self)
diff --git a/tests/unit/data/_sync/test__read_rows.py b/tests/unit/data/_sync/test__read_rows.py
index 015f96d98..556bdf0ec 100644
--- a/tests/unit/data/_sync/test__read_rows.py
+++ b/tests/unit/data/_sync/test__read_rows.py
@@ -300,7 +300,7 @@ def mock_stream():
         assert "emit count exceeds row limit" in str(e.value)
 
     def test_close(self):
-        """should be able to close a stream safely with aclose.
+        """should be able to close a stream safely with close.
         Closed generators should raise StopAsyncIteration on next yield"""
 
         def mock_stream():
diff --git a/tests/unit/data/execute_query/_sync/__init__.py b/tests/unit/data/execute_query/_sync/__init__.py
new file mode 100644
index 000000000..6d5e14bcf
--- /dev/null
+++ b/tests/unit/data/execute_query/_sync/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/data/execute_query/_sync/test_query_iterator.py b/tests/unit/data/execute_query/_sync/test_query_iterator.py
new file mode 100644
index 000000000..b3eded81e
--- /dev/null
+++ b/tests/unit/data/execute_query/_sync/test_query_iterator.py
@@ -0,0 +1,161 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This file is automatically generated by CrossSync. Do not edit manually.
+import pytest +import concurrent.futures +from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse +from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes +from google.cloud.bigtable.data._sync.cross_sync import CrossSync + +try: + from unittest import mock +except ImportError: + import mock + + +class MockIterator: + def __init__(self, values, delay=None): + self._values = values + self.idx = 0 + self._delay = delay + + def __iter__(self): + return self + + def __next__(self): + if self.idx >= len(self._values): + raise CrossSync._Sync_Impl.StopIteration + if self._delay is not None: + CrossSync._Sync_Impl.sleep(self._delay) + value = self._values[self.idx] + self.idx += 1 + return value + + +class TestQueryIterator: + @staticmethod + def _target_class(): + return CrossSync._Sync_Impl.ExecuteQueryIterator + + def _make_one(self, *args, **kwargs): + return self._target_class()(*args, **kwargs) + + @pytest.fixture + def proto_byte_stream(self): + proto_rows = [ + proto_rows_bytes({"int_value": 1}, {"int_value": 2}), + proto_rows_bytes({"int_value": 3}, {"int_value": 4}), + proto_rows_bytes({"int_value": 5}, {"int_value": 6}), + ] + messages = [ + *split_bytes_into_chunks(proto_rows[0], num_chunks=2), + *split_bytes_into_chunks(proto_rows[1], num_chunks=3), + proto_rows[2], + ] + stream = [ + ExecuteQueryResponse( + metadata={ + "proto_schema": { + "columns": [ + {"name": "test1", "type_": TYPE_INT}, + {"name": "test2", "type_": TYPE_INT}, + ] + } + } + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[0]}} + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[1]}, + "resume_token": b"token1", + } + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[2]}} + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[3]}} + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[4]}, + "resume_token": b"token2", + } + ), + ExecuteQueryResponse( + results={ + "proto_rows_batch": {"batch_data": messages[5]}, + "resume_token": b"token3", + } + ), + ] + return stream + + def test_iterator(self, proto_byte_stream): + client_mock = mock.Mock() + client_mock._register_instance = CrossSync._Sync_Impl.Mock() + client_mock._remove_instance_registration = CrossSync._Sync_Impl.Mock() + client_mock._executor = concurrent.futures.ThreadPoolExecutor() + mock_async_iterator = MockIterator(proto_byte_stream) + iterator = None + with mock.patch.object( + CrossSync._Sync_Impl, + "retry_target_stream", + return_value=mock_async_iterator, + ): + iterator = self._make_one( + client=client_mock, + instance_id="test-instance", + app_profile_id="test_profile", + request_body={}, + attempt_timeout=10, + operation_timeout=10, + req_metadata=(), + retryable_excs=[], + ) + result = [] + for value in iterator: + result.append(tuple(value)) + assert result == [(1, 2), (3, 4), (5, 6)] + assert iterator.is_closed + client_mock._register_instance.assert_called_once() + client_mock._remove_instance_registration.assert_called_once() + assert mock_async_iterator.idx == len(proto_byte_stream) + + def test_iterator_awaits_metadata(self, proto_byte_stream): + client_mock = mock.Mock() + client_mock._register_instance = CrossSync._Sync_Impl.Mock() + client_mock._remove_instance_registration = CrossSync._Sync_Impl.Mock() + mock_async_iterator = MockIterator(proto_byte_stream) + iterator = None + with mock.patch.object( + CrossSync._Sync_Impl, + 
"retry_target_stream", + return_value=mock_async_iterator, + ): + iterator = self._make_one( + client=client_mock, + instance_id="test-instance", + app_profile_id="test_profile", + request_body={}, + attempt_timeout=10, + operation_timeout=10, + req_metadata=(), + retryable_excs=[], + ) + iterator.metadata() + assert mock_async_iterator.idx == 1 From 792abd9e38da315b59cd770bf8d8d8a49246e6c9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 30 Aug 2024 16:36:22 -0700 Subject: [PATCH 224/360] added next to cross_sync --- google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py index 04ac79c73..1219c3b3d 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py @@ -105,6 +105,10 @@ class CrossSync(metaclass=MappingMeta): PytestFixture.decorator ) # decorate test methods to run with pytest fixture + @classmethod + def next(cls, iterable): + return iterable.__anext__() + @classmethod def Mock(cls, *args, **kwargs): """ @@ -230,7 +234,7 @@ def verify_async_event_loop() -> None: asyncio.get_running_loop() @staticmethod - def rm_aio(statement: Any) -> Any: + def rm_aio(statement: T) -> T: """ Used to annotate regions of the code containing async keywords to strip @@ -247,6 +251,7 @@ class _Sync_Impl(metaclass=MappingMeta): is_async = False sleep = time.sleep + next = next retry_target = retries.retry_target retry_target_stream = retries.retry_target_stream Retry = retries.Retry From 5dc231cfdf5022861276455d09f1bd1e6e9bb3f7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 30 Aug 2024 16:39:30 -0700 Subject: [PATCH 225/360] moved execute_query e2e tests from system to unit --- tests/system/data/test_execute_query_async.py | 286 -------------- tests/system/data/test_execute_query_utils.py | 272 ------------- tests/unit/data/_async/test_client.py | 360 ++++++++++++++++++ 3 files changed, 360 insertions(+), 558 deletions(-) delete mode 100644 tests/system/data/test_execute_query_async.py delete mode 100644 tests/system/data/test_execute_query_utils.py diff --git a/tests/system/data/test_execute_query_async.py b/tests/system/data/test_execute_query_async.py deleted file mode 100644 index cbb492c67..000000000 --- a/tests/system/data/test_execute_query_async.py +++ /dev/null @@ -1,286 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pytest - -import os -from unittest import mock -from .test_execute_query_utils import ( - ChannelMockAsync, - response_with_metadata, - response_with_result, -) -from google.api_core import exceptions as core_exceptions -from google.cloud.bigtable.data import BigtableDataClientAsync -from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -TABLE_NAME = "TABLE_NAME" -INSTANCE_NAME = "INSTANCE_NAME" - - -class TestAsyncExecuteQuery: - @pytest.fixture() - def async_channel_mock(self): - with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): - yield ChannelMockAsync() - - @pytest.fixture() - def async_client(self, async_channel_mock): - with mock.patch.dict( - os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"} - ), mock.patch.object( - CrossSync, "PooledChannel", return_value=async_channel_mock - ): - yield BigtableDataClientAsync() - - @pytest.mark.asyncio - async def test_execute_query(self, async_client, async_channel_mock): - values = [ - response_with_metadata(), - response_with_result("test"), - response_with_result(8, resume_token=b"r1"), - response_with_result("test2"), - response_with_result(9, resume_token=b"r2"), - response_with_result("test3"), - response_with_result(None, resume_token=b"r3"), - ] - async_channel_mock.set_values(values) - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - results = [r async for r in result] - assert results[0]["a"] == "test" - assert results[0]["b"] == 8 - assert results[1]["a"] == "test2" - assert results[1]["b"] == 9 - assert results[2]["a"] == "test3" - assert results[2]["b"] is None - assert len(async_channel_mock.execute_query_calls) == 1 - - @pytest.mark.asyncio - async def test_execute_query_with_params(self, async_client, async_channel_mock): - values = [ - response_with_metadata(), - response_with_result("test2"), - response_with_result(9, resume_token=b"r2"), - ] - async_channel_mock.set_values(values) - - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME} WHERE b=@b", - INSTANCE_NAME, - parameters={"b": 9}, - ) - results = [r async for r in result] - assert len(results) == 1 - assert results[0]["a"] == "test2" - assert results[0]["b"] == 9 - assert len(async_channel_mock.execute_query_calls) == 1 - - @pytest.mark.asyncio - async def test_execute_query_error_before_metadata( - self, async_client, async_channel_mock - ): - from google.api_core.exceptions import DeadlineExceeded - - values = [ - DeadlineExceeded(""), - response_with_metadata(), - response_with_result("test"), - response_with_result(8, resume_token=b"r1"), - response_with_result("test2"), - response_with_result(9, resume_token=b"r2"), - response_with_result("test3"), - response_with_result(None, resume_token=b"r3"), - ] - async_channel_mock.set_values(values) - - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - results = [r async for r in result] - assert len(results) == 3 - assert len(async_channel_mock.execute_query_calls) == 2 - - @pytest.mark.asyncio - async def test_execute_query_error_after_metadata( - self, async_client, async_channel_mock - ): - from google.api_core.exceptions import DeadlineExceeded - - values = [ - response_with_metadata(), - DeadlineExceeded(""), - response_with_metadata(), - response_with_result("test"), - response_with_result(8, resume_token=b"r1"), - response_with_result("test2"), - response_with_result(9, resume_token=b"r2"), - response_with_result("test3"), - 
response_with_result(None, resume_token=b"r3"), - ] - async_channel_mock.set_values(values) - - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - results = [r async for r in result] - assert len(results) == 3 - assert len(async_channel_mock.execute_query_calls) == 2 - assert async_channel_mock.resume_tokens == [] - - @pytest.mark.asyncio - async def test_execute_query_with_retries(self, async_client, async_channel_mock): - from google.api_core.exceptions import DeadlineExceeded - - values = [ - response_with_metadata(), - response_with_result("test"), - response_with_result(8, resume_token=b"r1"), - DeadlineExceeded(""), - response_with_result("test2"), - response_with_result(9, resume_token=b"r2"), - response_with_result("test3"), - DeadlineExceeded(""), - response_with_result("test3"), - response_with_result(None, resume_token=b"r3"), - ] - async_channel_mock.set_values(values) - - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - results = [r async for r in result] - assert results[0]["a"] == "test" - assert results[0]["b"] == 8 - assert results[1]["a"] == "test2" - assert results[1]["b"] == 9 - assert results[2]["a"] == "test3" - assert results[2]["b"] is None - assert len(async_channel_mock.execute_query_calls) == 3 - assert async_channel_mock.resume_tokens == [b"r1", b"r2"] - - @pytest.mark.parametrize( - "exception", - [ - (core_exceptions.DeadlineExceeded("")), - (core_exceptions.Aborted("")), - (core_exceptions.ServiceUnavailable("")), - ], - ) - @pytest.mark.asyncio - async def test_execute_query_retryable_error( - self, async_client, async_channel_mock, exception - ): - values = [ - response_with_metadata(), - response_with_result("test", resume_token=b"t1"), - exception, - response_with_result(8, resume_token=b"t2"), - ] - async_channel_mock.set_values(values) - - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - results = [r async for r in result] - assert len(results) == 1 - assert len(async_channel_mock.execute_query_calls) == 2 - assert async_channel_mock.resume_tokens == [b"t1"] - - @pytest.mark.asyncio - async def test_execute_query_retry_partial_row( - self, async_client, async_channel_mock - ): - values = [ - response_with_metadata(), - response_with_result("test", resume_token=b"t1"), - core_exceptions.DeadlineExceeded(""), - response_with_result(8, resume_token=b"t2"), - ] - async_channel_mock.set_values(values) - - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - results = [r async for r in result] - assert results[0]["a"] == "test" - assert results[0]["b"] == 8 - assert len(async_channel_mock.execute_query_calls) == 2 - assert async_channel_mock.resume_tokens == [b"t1"] - - @pytest.mark.parametrize( - "ExceptionType", - [ - (core_exceptions.InvalidArgument), - (core_exceptions.FailedPrecondition), - (core_exceptions.PermissionDenied), - (core_exceptions.MethodNotImplemented), - (core_exceptions.Cancelled), - (core_exceptions.AlreadyExists), - (core_exceptions.OutOfRange), - (core_exceptions.DataLoss), - (core_exceptions.Unauthenticated), - (core_exceptions.NotFound), - (core_exceptions.ResourceExhausted), - (core_exceptions.Unknown), - (core_exceptions.InternalServerError), - ], - ) - @pytest.mark.asyncio - async def test_execute_query_non_retryable( - self, async_client, async_channel_mock, ExceptionType - ): - values = [ - response_with_metadata(), - 
response_with_result("test"), - response_with_result(8, resume_token=b"r1"), - ExceptionType(""), - response_with_result("test2"), - response_with_result(9, resume_token=b"r2"), - response_with_result("test3"), - response_with_result(None, resume_token=b"r3"), - ] - async_channel_mock.set_values(values) - - result = await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - r = await result.__anext__() - assert r["a"] == "test" - assert r["b"] == 8 - - with pytest.raises(ExceptionType): - r = await result.__anext__() - - assert len(async_channel_mock.execute_query_calls) == 1 - assert async_channel_mock.resume_tokens == [] - - @pytest.mark.asyncio - async def test_execute_query_metadata_received_multiple_times_detected( - self, async_client, async_channel_mock - ): - values = [ - response_with_metadata(), - response_with_metadata(), - ] - async_channel_mock.set_values(values) - - with pytest.raises(Exception, match="Invalid ExecuteQuery response received"): - [ - r - async for r in await async_client.execute_query( - f"SELECT a, b FROM {TABLE_NAME}", INSTANCE_NAME - ) - ] diff --git a/tests/system/data/test_execute_query_utils.py b/tests/system/data/test_execute_query_utils.py deleted file mode 100644 index 9e27b95f2..000000000 --- a/tests/system/data/test_execute_query_utils.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from unittest import mock - -import google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio as pga -from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse -from google.cloud.bigtable_v2.types.data import ProtoRows, Value as PBValue -import grpc.aio - - -try: - # async mock for python3.7-10 - from asyncio import coroutine - - def async_mock(return_value=None): - coro = mock.Mock(name="CoroutineResult") - corofunc = mock.Mock(name="CoroutineFunction", side_effect=coroutine(coro)) - corofunc.coro = coro - corofunc.coro.return_value = return_value - return corofunc - -except ImportError: - # async mock for python3.11 or later - from unittest.mock import AsyncMock - - def async_mock(return_value=None): - return AsyncMock(return_value=return_value) - - -# ExecuteQueryResponse( -# metadata={ -# "proto_schema": { -# "columns": [ -# {"name": "test1", "type_": TYPE_INT}, -# {"name": "test2", "type_": TYPE_INT}, -# ] -# } -# } -# ), -# ExecuteQueryResponse( -# results={"proto_rows_batch": {"batch_data": messages[0]}} -# ), - - -def response_with_metadata(): - schema = {"a": "string_type", "b": "int64_type"} - return ExecuteQueryResponse( - { - "metadata": { - "proto_schema": { - "columns": [ - {"name": name, "type_": {_type: {}}} - for name, _type in schema.items() - ] - } - } - } - ) - - -def response_with_result(*args, resume_token=None): - if resume_token is None: - resume_token_dict = {} - else: - resume_token_dict = {"resume_token": resume_token} - - values = [] - for column_value in args: - if column_value is None: - pb_value = PBValue({}) - else: - pb_value = PBValue( - { - "int_value" - if isinstance(column_value, int) - else "string_value": column_value - } - ) - values.append(pb_value) - rows = ProtoRows(values=values) - - return ExecuteQueryResponse( - { - "results": { - "proto_rows_batch": { - "batch_data": ProtoRows.serialize(rows), - }, - **resume_token_dict, - } - } - ) - - -class ExecuteQueryStreamMock: - def __init__(self, parent): - self.parent = parent - self.iter = iter(self.parent.values) - - def __call__(self, *args, **kwargs): - request = args[0] - - self.parent.execute_query_calls.append(request) - if request.resume_token: - self.parent.resume_tokens.append(request.resume_token) - - def stream(): - for value in self.iter: - if isinstance(value, Exception): - raise value - else: - yield value - - return stream() - - -class ChannelMock: - def __init__(self): - self.execute_query_calls = [] - self.values = [] - self.resume_tokens = [] - - def set_values(self, values): - self.values = values - - def unary_unary(self, *args, **kwargs): - return mock.MagicMock() - - def unary_stream(self, *args, **kwargs): - if args[0] == "/google.bigtable.v2.Bigtable/ExecuteQuery": - return ExecuteQueryStreamMock(self) - return mock.MagicMock() - - -class ChannelMockAsync(pga.PooledChannel, mock.MagicMock): - def __init__(self, *args, **kwargs): - mock.MagicMock.__init__(self, *args, **kwargs) - self.execute_query_calls = [] - self.values = [] - self.resume_tokens = [] - self._iter = [] - - def get_async_get(self, *args, **kwargs): - return self.async_gen - - def set_values(self, values): - self.values = values - self._iter = iter(self.values) - - def unary_unary(self, *args, **kwargs): - return async_mock() - - def unary_stream(self, *args, **kwargs): - if args[0] == "/google.bigtable.v2.Bigtable/ExecuteQuery": - - async def async_gen(*args, **kwargs): - for value in self._iter: - yield value - - iter = async_gen() - - class 
UnaryStreamCallMock(grpc.aio.UnaryStreamCall): - def __aiter__(self): - async def _impl(*args, **kwargs): - try: - while True: - yield await self.read() - except StopAsyncIteration: - pass - - return _impl() - - async def read(self): - value = await iter.__anext__() - if isinstance(value, Exception): - raise value - return value - - def add_done_callback(*args, **kwargs): - pass - - def cancel(*args, **kwargs): - pass - - def cancelled(*args, **kwargs): - pass - - def code(*args, **kwargs): - pass - - def details(*args, **kwargs): - pass - - def done(*args, **kwargs): - pass - - def initial_metadata(*args, **kwargs): - pass - - def time_remaining(*args, **kwargs): - pass - - def trailing_metadata(*args, **kwargs): - pass - - async def wait_for_connection(*args, **kwargs): - return async_mock() - - class UnaryStreamMultiCallableMock(grpc.aio.UnaryStreamMultiCallable): - def __init__(self, parent): - self.parent = parent - - def __call__( - self, - request, - *, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None - ): - self.parent.execute_query_calls.append(request) - if request.resume_token: - self.parent.resume_tokens.append(request.resume_token) - return UnaryStreamCallMock() - - def add_done_callback(*args, **kwargs): - pass - - def cancel(*args, **kwargs): - pass - - def cancelled(*args, **kwargs): - pass - - def code(*args, **kwargs): - pass - - def details(*args, **kwargs): - pass - - def done(*args, **kwargs): - pass - - def initial_metadata(*args, **kwargs): - pass - - def time_remaining(*args, **kwargs): - pass - - def trailing_metadata(*args, **kwargs): - pass - - def wait_for_connection(*args, **kwargs): - pass - - # unary_stream should return https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc.aio.UnaryStreamMultiCallable - # PTAL https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc.aio.Channel.unary_stream - return UnaryStreamMultiCallableMock(self) - return async_mock() diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 7cbb89c14..bd5975d6e 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -3161,3 +3161,363 @@ async def test_read_modify_write_row_building(self): await table.read_modify_write_row("key", mock.Mock()) assert constructor_mock.call_count == 1 constructor_mock.assert_called_once_with(mock_response.row) + + +@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestExecuteQuery") +class TestExecuteQueryAsync: + TABLE_NAME = "TABLE_NAME" + INSTANCE_NAME = "INSTANCE_NAME" + + @CrossSync.convert + def _make_client(self, *args, **kwargs): + return CrossSync.TestBigtableDataClient._make_client(*args, **kwargs) + + @CrossSync.convert + def _make_gapic_stream(self, sample_list: list["ExecuteQueryResponse" | Exception]): + class MockStream: + def __init__(self, sample_list): + self.sample_list = sample_list + + def __aiter__(self): + return self + + def __iter__(self): + return self + + def __next__(self): + if not self.sample_list: + raise CrossSync.StopIteration + value = self.sample_list.pop(0) + if isinstance(value, Exception): + raise value + return value + + async def __anext__(self): + return self.__next__() + + return MockStream(sample_list) + + def resonse_with_metadata(self): + from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse + + schema = {"a": "string_type", "b": "int64_type"} + return ExecuteQueryResponse( + { + "metadata": { + "proto_schema": { + "columns": [ + {"name": name, 
"type_": {_type: {}}} + for name, _type in schema.items() + ] + } + } + } + ) + + def resonse_with_result(self, *args, resume_token=None): + from google.cloud.bigtable_v2.types.data import ProtoRows, Value as PBValue + from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse + + if resume_token is None: + resume_token_dict = {} + else: + resume_token_dict = {"resume_token": resume_token} + + values = [] + for column_value in args: + if column_value is None: + pb_value = PBValue({}) + else: + pb_value = PBValue( + { + "int_value" + if isinstance(column_value, int) + else "string_value": column_value + } + ) + values.append(pb_value) + rows = ProtoRows(values=values) + + return ExecuteQueryResponse( + { + "results": { + "proto_rows_batch": { + "batch_data": ProtoRows.serialize(rows), + }, + **resume_token_dict, + } + } + ) + + @CrossSync.pytest + async def test_execute_query(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r async for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert results[1]["a"] == "test2" + assert results[1]["b"] == 9 + assert results[2]["a"] == "test3" + assert results[2]["b"] is None + assert execute_query_mock.call_count == 1 + + @CrossSync.pytest + async def test_execute_query_with_params(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME} WHERE b=@b", + self.INSTANCE_NAME, + parameters={"b": 9}, + ) + results = [r async for r in result] + assert len(results) == 1 + assert results[0]["a"] == "test2" + assert results[0]["b"] == 9 + assert execute_query_mock.call_count == 1 + + @CrossSync.pytest + async def test_execute_query_error_before_metadata(self): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + DeadlineExceeded(""), + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r async for r in result] + assert len(results) == 3 + assert execute_query_mock.call_count == 2 + + @CrossSync.pytest + async def test_execute_query_error_after_metadata(self): + from 
google.api_core.exceptions import DeadlineExceeded + + values = [ + self.resonse_with_metadata(), + DeadlineExceeded(""), + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r async for r in result] + assert len(results) == 3 + assert execute_query_mock.call_count == 2 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [] + + @CrossSync.pytest + async def test_execute_query_with_retries(self): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + DeadlineExceeded(""), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + DeadlineExceeded(""), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r async for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert results[1]["a"] == "test2" + assert results[1]["b"] == 9 + assert results[2]["a"] == "test3" + assert results[2]["b"] is None + assert len(results) == 3 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [b"r1", b"r2"] + + @pytest.mark.parametrize( + "exception", + [ + (core_exceptions.DeadlineExceeded("")), + (core_exceptions.Aborted("")), + (core_exceptions.ServiceUnavailable("")), + ], + ) + @CrossSync.pytest + async def test_execute_query_retryable_error(self, exception): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test", resume_token=b"t1"), + exception, + self.resonse_with_result(8, resume_token=b"t2"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r async for r in result] + assert len(results) == 1 + assert execute_query_mock.call_count == 2 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [b"t1"] + + @CrossSync.pytest + async def test_execute_query_retry_partial_row(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test", resume_token=b"t1"), + 
core_exceptions.DeadlineExceeded(""), + self.resonse_with_result(8, resume_token=b"t2"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r async for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert execute_query_mock.call_count == 2 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [b"t1"] + + @pytest.mark.parametrize( + "ExceptionType", + [ + (core_exceptions.InvalidArgument), + (core_exceptions.FailedPrecondition), + (core_exceptions.PermissionDenied), + (core_exceptions.MethodNotImplemented), + (core_exceptions.Cancelled), + (core_exceptions.AlreadyExists), + (core_exceptions.OutOfRange), + (core_exceptions.DataLoss), + (core_exceptions.Unauthenticated), + (core_exceptions.NotFound), + (core_exceptions.ResourceExhausted), + (core_exceptions.Unknown), + (core_exceptions.InternalServerError), + ], + ) + @CrossSync.pytest + async def test_execute_query_non_retryable(self, ExceptionType): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + ExceptionType(""), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + + result = await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + r = await CrossSync.next(result) + assert r["a"] == "test" + assert r["b"] == 8 + + with pytest.raises(ExceptionType): + r = await CrossSync.next(result) + + assert execute_query_mock.call_count == 1 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [] + + @CrossSync.pytest + async def test_execute_query_metadata_received_multiple_times_detected(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_metadata(), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + with pytest.raises( + Exception, match="Invalid ExecuteQuery response received" + ): + [ + r + async for r in await client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + ] From f73498cbc6bfc517a66f2c09e27271228d6d4d88 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 30 Aug 2024 16:41:41 -0700 Subject: [PATCH 226/360] ran blacken --- google/cloud/bigtable/data/_async/client.py | 18 +++++++++-- .../_async/execute_query_iterator.py | 4 ++- .../_async/test_query_iterator.py | 31 +++++++++++++------ 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 340a948dd..2b58d83cc 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ 
b/google/cloud/bigtable/data/_async/client.py @@ -375,7 +375,12 @@ async def _manage_channel( next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = next_refresh - (time.monotonic() - start_timestamp) - @CrossSync.convert(replace_symbols={"TableAsync": "Table", "ExecuteQueryIteratorAsync": "ExecuteQueryIterator"}) + @CrossSync.convert( + replace_symbols={ + "TableAsync": "Table", + "ExecuteQueryIteratorAsync": "ExecuteQueryIterator", + } + ) async def _register_instance( self, instance_id: str, owner: TableAsync | ExecuteQueryIteratorAsync ) -> None: @@ -410,7 +415,12 @@ async def _register_instance( # refresh tasks aren't active. start them as background tasks self._start_background_channel_refresh() - @CrossSync.convert(replace_symbols={"TableAsync": "Table", "ExecuteQueryIteratorAsync": "ExecuteQueryIterator"}) + @CrossSync.convert( + replace_symbols={ + "TableAsync": "Table", + "ExecuteQueryIteratorAsync": "ExecuteQueryIterator", + } + ) async def _remove_instance_registration( self, instance_id: str, owner: TableAsync | ExecuteQueryIteratorAsync ) -> bool: @@ -483,7 +493,9 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs """ return TableAsync(self, instance_id, table_id, *args, **kwargs) - @CrossSync.convert(replace_symbols={"ExecuteQueryIteratorAsync": "ExecuteQueryIterator"}) + @CrossSync.convert( + replace_symbols={"ExecuteQueryIteratorAsync": "ExecuteQueryIterator"} + ) async def execute_query( self, query: str, diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index 9d2832d10..c293282d7 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -53,7 +53,9 @@ from google.cloud.bigtable.data import BigtableDataClient as DataClientType -@CrossSync.export_sync(path="google.cloud.bigtable.data.execute_query._sync.execute_query_iterator.ExecuteQueryIterator") +@CrossSync.export_sync( + path="google.cloud.bigtable.data.execute_query._sync.execute_query_iterator.ExecuteQueryIterator" +) class ExecuteQueryIteratorAsync: """ ExecuteQueryIteratorAsync handles collecting streaming responses from the diff --git a/tests/unit/data/execute_query/_async/test_query_iterator.py b/tests/unit/data/execute_query/_async/test_query_iterator.py index 08a7e0711..f55b2f9b5 100644 --- a/tests/unit/data/execute_query/_async/test_query_iterator.py +++ b/tests/unit/data/execute_query/_async/test_query_iterator.py @@ -30,7 +30,9 @@ import mock # type: ignore -@CrossSync.export_sync(path="tests.unit.data.execute_query._sync.test_query_iterator.MockIterator") +@CrossSync.export_sync( + path="tests.unit.data.execute_query._sync.test_query_iterator.MockIterator" +) class MockIterator: def __init__(self, values, delay=None): self._values = values @@ -52,9 +54,10 @@ async def __anext__(self): return value -@CrossSync.export_sync(path="tests.unit.data.execute_query._sync.test_query_iterator.TestQueryIterator") +@CrossSync.export_sync( + path="tests.unit.data.execute_query._sync.test_query_iterator.TestQueryIterator" +) class TestQueryIteratorAsync: - @staticmethod def _target_class(): return CrossSync.ExecuteQueryIterator @@ -87,15 +90,21 @@ def proto_byte_stream(self): } } ), - ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[0]}}), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": 
messages[0]}} + ), ExecuteQueryResponse( results={ "proto_rows_batch": {"batch_data": messages[1]}, "resume_token": b"token1", } ), - ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[2]}}), - ExecuteQueryResponse(results={"proto_rows_batch": {"batch_data": messages[3]}}), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[2]}} + ), + ExecuteQueryResponse( + results={"proto_rows_batch": {"batch_data": messages[3]}} + ), ExecuteQueryResponse( results={ "proto_rows_batch": {"batch_data": messages[4]}, @@ -111,7 +120,6 @@ def proto_byte_stream(self): ] return stream - @CrossSync.pytest async def test_iterator(self, proto_byte_stream): client_mock = mock.Mock() @@ -123,7 +131,9 @@ async def test_iterator(self, proto_byte_stream): iterator = None with mock.patch.object( - CrossSync, "retry_target_stream", return_value=mock_async_iterator, + CrossSync, + "retry_target_stream", + return_value=mock_async_iterator, ): iterator = self._make_one( client=client_mock, @@ -146,7 +156,6 @@ async def test_iterator(self, proto_byte_stream): assert mock_async_iterator.idx == len(proto_byte_stream) - @CrossSync.pytest async def test_iterator_awaits_metadata(self, proto_byte_stream): client_mock = mock.Mock() @@ -156,7 +165,9 @@ async def test_iterator_awaits_metadata(self, proto_byte_stream): mock_async_iterator = MockIterator(proto_byte_stream) iterator = None with mock.patch.object( - CrossSync, "retry_target_stream", return_value=mock_async_iterator, + CrossSync, + "retry_target_stream", + return_value=mock_async_iterator, ): iterator = self._make_one( client=client_mock, From d884e5b42a523a926729d54a1cac06dab9c00133 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 30 Aug 2024 16:43:36 -0700 Subject: [PATCH 227/360] added execute_query tests --- tests/unit/data/_sync/test_client.py | 335 +++++++++++++++++++++++++++ 1 file changed, 335 insertions(+) diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 807a4914f..7de47de5d 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -2742,3 +2742,338 @@ def test_read_modify_write_row_building(self): table.read_modify_write_row("key", mock.Mock()) assert constructor_mock.call_count == 1 constructor_mock.assert_called_once_with(mock_response.row) + + +class TestExecuteQuery: + TABLE_NAME = "TABLE_NAME" + INSTANCE_NAME = "INSTANCE_NAME" + + def _make_client(self, *args, **kwargs): + return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) + + def _make_gapic_stream(self, sample_list: list["ExecuteQueryResponse" | Exception]): + class MockStream: + def __init__(self, sample_list): + self.sample_list = sample_list + + def __aiter__(self): + return self + + def __iter__(self): + return self + + def __next__(self): + if not self.sample_list: + raise CrossSync._Sync_Impl.StopIteration + value = self.sample_list.pop(0) + if isinstance(value, Exception): + raise value + return value + + async def __anext__(self): + return self.__next__() + + return MockStream(sample_list) + + def resonse_with_metadata(self): + from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse + + schema = {"a": "string_type", "b": "int64_type"} + return ExecuteQueryResponse( + { + "metadata": { + "proto_schema": { + "columns": [ + {"name": name, "type_": {_type: {}}} + for name, _type in schema.items() + ] + } + } + } + ) + + def resonse_with_result(self, *args, resume_token=None): + from 
google.cloud.bigtable_v2.types.data import ProtoRows, Value as PBValue + from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse + + if resume_token is None: + resume_token_dict = {} + else: + resume_token_dict = {"resume_token": resume_token} + values = [] + for column_value in args: + if column_value is None: + pb_value = PBValue({}) + else: + pb_value = PBValue( + { + "int_value" + if isinstance(column_value, int) + else "string_value": column_value + } + ) + values.append(pb_value) + rows = ProtoRows(values=values) + return ExecuteQueryResponse( + { + "results": { + "proto_rows_batch": {"batch_data": ProtoRows.serialize(rows)}, + **resume_token_dict, + } + } + ) + + def test_execute_query(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert results[1]["a"] == "test2" + assert results[1]["b"] == 9 + assert results[2]["a"] == "test3" + assert results[2]["b"] is None + assert execute_query_mock.call_count == 1 + + def test_execute_query_with_params(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME} WHERE b=@b", + self.INSTANCE_NAME, + parameters={"b": 9}, + ) + results = [r for r in result] + assert len(results) == 1 + assert results[0]["a"] == "test2" + assert results[0]["b"] == 9 + assert execute_query_mock.call_count == 1 + + def test_execute_query_error_before_metadata(self): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + DeadlineExceeded(""), + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r for r in result] + assert len(results) == 3 + assert execute_query_mock.call_count == 2 + + def test_execute_query_error_after_metadata(self): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + self.resonse_with_metadata(), + DeadlineExceeded(""), + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + 
self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r for r in result] + assert len(results) == 3 + assert execute_query_mock.call_count == 2 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [] + + def test_execute_query_with_retries(self): + from google.api_core.exceptions import DeadlineExceeded + + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + DeadlineExceeded(""), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + DeadlineExceeded(""), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert results[1]["a"] == "test2" + assert results[1]["b"] == 9 + assert results[2]["a"] == "test3" + assert results[2]["b"] is None + assert len(results) == 3 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [b"r1", b"r2"] + + @pytest.mark.parametrize( + "exception", + [ + core_exceptions.DeadlineExceeded(""), + core_exceptions.Aborted(""), + core_exceptions.ServiceUnavailable(""), + ], + ) + def test_execute_query_retryable_error(self, exception): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test", resume_token=b"t1"), + exception, + self.resonse_with_result(8, resume_token=b"t2"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r for r in result] + assert len(results) == 1 + assert execute_query_mock.call_count == 2 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [b"t1"] + + def test_execute_query_retry_partial_row(self): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test", resume_token=b"t1"), + core_exceptions.DeadlineExceeded(""), + self.resonse_with_result(8, resume_token=b"t2"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, 
b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + results = [r for r in result] + assert results[0]["a"] == "test" + assert results[0]["b"] == 8 + assert execute_query_mock.call_count == 2 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [b"t1"] + + @pytest.mark.parametrize( + "ExceptionType", + [ + core_exceptions.InvalidArgument, + core_exceptions.FailedPrecondition, + core_exceptions.PermissionDenied, + core_exceptions.MethodNotImplemented, + core_exceptions.Cancelled, + core_exceptions.AlreadyExists, + core_exceptions.OutOfRange, + core_exceptions.DataLoss, + core_exceptions.Unauthenticated, + core_exceptions.NotFound, + core_exceptions.ResourceExhausted, + core_exceptions.Unknown, + core_exceptions.InternalServerError, + ], + ) + def test_execute_query_non_retryable(self, ExceptionType): + values = [ + self.resonse_with_metadata(), + self.resonse_with_result("test"), + self.resonse_with_result(8, resume_token=b"r1"), + ExceptionType(""), + self.resonse_with_result("test2"), + self.resonse_with_result(9, resume_token=b"r2"), + self.resonse_with_result("test3"), + self.resonse_with_result(None, resume_token=b"r3"), + ] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + result = client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + r = CrossSync._Sync_Impl.next(result) + assert r["a"] == "test" + assert r["b"] == 8 + with pytest.raises(ExceptionType): + r = CrossSync._Sync_Impl.next(result) + assert execute_query_mock.call_count == 1 + requests = [args[0][0] for args in execute_query_mock.call_args_list] + resume_tokens = [r.resume_token for r in requests if r.resume_token] + assert resume_tokens == [] + + def test_execute_query_metadata_received_multiple_times_detected(self): + values = [self.resonse_with_metadata(), self.resonse_with_metadata()] + client = self._make_client() + with mock.patch.object( + client._gapic_client, "execute_query", CrossSync._Sync_Impl.Mock() + ) as execute_query_mock: + execute_query_mock.return_value = self._make_gapic_stream(values) + with pytest.raises( + Exception, match="Invalid ExecuteQuery response received" + ): + [ + r + for r in client.execute_query( + f"SELECT a, b FROM {self.TABLE_NAME}", self.INSTANCE_NAME + ) + ] From 18854f03aa0bd672b10ebb59ce0b9322b829c071 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 3 Sep 2024 15:42:33 -0700 Subject: [PATCH 228/360] added docstring templating --- .../data/_sync/cross_sync/_decorators.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index 2788ffec4..bf4d855de 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -160,6 +160,8 @@ def _convert_ast_to_py(cls, ast_node: ast.expr | None) -> Any: return ast_node.value if isinstance(ast_node, ast.List): return [cls._convert_ast_to_py(node) for node in ast_node.elts] + if isinstance(ast_node, ast.Tuple): + return tuple(cls._convert_ast_to_py(node) for node in ast_node.elts) if isinstance(ast_node, ast.Dict): return { cls._convert_ast_to_py(k): cls._convert_ast_to_py(v) @@ -175,6 +177,7 @@ class 
ExportSync(AstDecorator): Args: path: path to output the generated sync class replace_symbols: a dict of symbols and replacements to use when generating sync class + docstring_format_vars: a dict of variables to replace in the docstring mypy_ignore: set of mypy errors to ignore in the generated file include_file_imports: if True, include top-level imports from the file in the generated sync class add_mapping_for_name: when given, will add a new attribute to CrossSync, so the original class and its sync version can be accessed from CrossSync. @@ -185,12 +188,16 @@ def __init__( path: str, *, replace_symbols: dict[str, str] | None = None, + docstring_format_vars: dict[str, tuple[str, str]] | None = None, mypy_ignore: Sequence[str] = (), include_file_imports: bool = True, add_mapping_for_name: str | None = None, ): self.path = path self.replace_symbols = replace_symbols + docstring_format_vars = docstring_format_vars or {} + self.async_docstring_format_vars = {k: v[0] for k, v in docstring_format_vars.items()} + self.sync_docstring_format_vars = {k: v[1] for k, v in docstring_format_vars.items()} self.mypy_ignore = mypy_ignore self.include_file_imports = include_file_imports self.add_mapping_for_name = add_mapping_for_name @@ -206,6 +213,8 @@ def async_decorator(self): def decorator(cls): if new_mapping: CrossSync.add_mapping(new_mapping, cls) + if self.async_docstring_format_vars: + cls.__doc__ = cls.__doc__.format(**self.async_docstring_format_vars) return cls return decorator @@ -257,6 +266,9 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): wrapped_node = transformers_globals["CrossSyncMethodDecoratorHandler"]().visit( wrapped_node ) + if self.sync_docstring_format_vars: + docstring = ast.get_docstring(wrapped_node) + wrapped_node.body[0].value.s = docstring.format(**self.sync_docstring_format_vars) return wrapped_node @@ -267,6 +279,7 @@ class Convert(AstDecorator): Args: sync_name: use a new name for the sync method replace_symbols: a dict of symbols and replacements to use when generating sync method + docstring_format_vars: a dict of variables to replace in the docstring rm_aio: if True, automatically strip all asyncio keywords from method. If False, only the signature `async def` is stripped. Other keywords must be wrapped in CrossSync.rm_aio() calls to be removed. 
@@ -277,10 +290,14 @@ def __init__( *, sync_name: str | None = None, replace_symbols: dict[str, str] | None = None, + docstring_format_vars: dict[str, tuple[str, str]] | None = None, rm_aio: bool = False, ): self.sync_name = sync_name self.replace_symbols = replace_symbols + docstring_format_vars = docstring_format_vars or {} + self.async_docstring_format_vars = {k: v[0] for k, v in docstring_format_vars.items()} + self.sync_docstring_format_vars = {k: v[1] for k, v in docstring_format_vars.items()} self.rm_aio = rm_aio def sync_ast_transform(self, wrapped_node, transformers_globals): @@ -310,8 +327,25 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): if self.replace_symbols: replacer = transformers_globals["SymbolReplacer"] wrapped_node = replacer(self.replace_symbols).visit(wrapped_node) + # update docstring if specified + if self.sync_docstring_format_vars: + docstring = ast.get_docstring(wrapped_node) + wrapped_node.body[0].value.s = docstring.format(**self.sync_docstring_format_vars) return wrapped_node + def async_decorator(self): + """ + If docstring_format_vars are provided, update the docstring of the async method + """ + + if self.async_docstring_format_vars: + def decorator(f): + f.__doc__ = f.__doc__.format(**self.async_docstring_format_vars) + return f + return decorator + else: + return None + class DropMethod(AstDecorator): """ From 5396bc451fd19f98ca7dd029d5a663fd832e5c04 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 3 Sep 2024 15:43:50 -0700 Subject: [PATCH 229/360] use templating in docstrings --- google/cloud/bigtable/data/_async/client.py | 40 ++++++++++---- .../_async/execute_query_iterator.py | 54 ++++++++++--------- 2 files changed, 58 insertions(+), 36 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 2b58d83cc..827db9839 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -112,7 +112,12 @@ add_mapping_for_name="DataClient", ) class BigtableDataClientAsync(ClientWithProject): - @CrossSync.convert + @CrossSync.convert( + docstring_format_vars={ + "LOOP_MESSAGE": ("Client should be created within an async context (running event loop)", ""), + "RAISE_NO_LOOP": ("RuntimeError: if called outside of an async context (no running event loop)", ""), + } + ) def __init__( self, *, @@ -126,7 +131,7 @@ def __init__( """ Create a client instance for the Bigtable Data API - Client should be created within an async context (running event loop) + {LOOP_MESSAGE} Args: project: the project which the client acts on behalf of. @@ -143,8 +148,8 @@ def __init__( Client options used to set user options on the client. API Endpoint should be set through client_options. 
Raises: - RuntimeError: if called outside of an async context (no running event loop) ValueError: if pool_size is less than 1 + {RAISE_NO_LOOP} """ # set up transport in registry transport_str = f"bt-{self._client_version()}-{pool_size}" @@ -228,12 +233,15 @@ def _client_version() -> str: version_str += "-async" return version_str + @CrossSync.convert( + docstring_format_vars={"RAISE_NO_LOOP": ("RuntimeError: if not called in an asyncio event loop", "None")} + ) def _start_background_channel_refresh(self) -> None: """ Starts a background task to ping and warm each channel in the pool Raises: - RuntimeError: if not called in an asyncio event loop + {RAISE_NO_LOOP} """ if ( not self._channel_refresh_tasks @@ -320,7 +328,7 @@ async def _manage_channel( grace_period: float = 60 * 10, ) -> None: """ - Background coroutine that periodically refreshes and warms a grpc channel + Background task that periodically refreshes and warms a grpc channel The backend will automatically close channels after 60 minutes, so `refresh_interval` + `grace_period` should be < 60 minutes @@ -451,12 +459,20 @@ async def _remove_instance_registration( except KeyError: return False - @CrossSync.convert(replace_symbols={"TableAsync": "Table"}) + @CrossSync.convert( + replace_symbols={"TableAsync": "Table"}, + docstring_format_vars={ + "LOOP_MESSAGE": ("Must be created within an async context (running event loop)", ""), + "RAISE_NO_LOOP": ("RuntimeError: if called outside of an async context (no running event loop)", "None"), + }, + ) def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: """ Returns a table instance for making data API requests. All arguments are passed directly to the TableAsync constructor. + {LOOP_MESSAGE} + Args: instance_id: The Bigtable instance ID to associate with this client. instance_id is combined with the client's project to fully @@ -489,7 +505,7 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAs Returns: TableAsync: a table instance for making data API requests Raises: - RuntimeError: if called outside of an async context (no running event loop) + {RAISE_NO_LOOP} """ return TableAsync(self, instance_id, table_id, *args, **kwargs) @@ -615,7 +631,11 @@ class TableAsync: """ @CrossSync.convert( - replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"} + replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}, + docstring_format_vars={ + "LOOP_MESSAGE": ("Must be created within an async context (running event loop)", ""), + "RAISE_NO_LOOP": ("RuntimeError: if called outside of an async context (no running event loop)", "None"), + } ) def __init__( self, @@ -647,7 +667,7 @@ def __init__( """ Initialize a Table instance - Must be created within an async context (running event loop) + {LOOP_MESSAGE} Args: instance_id: The Bigtable instance ID to associate with this client. @@ -679,7 +699,7 @@ def __init__( encountered during all other operations. 
Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) Raises: - RuntimeError: if called outside of an async context (no running event loop) + {RAISE_NO_LOOP} """ # NOTE: any changes to the signature of this method should also be reflected # in client.get_table() diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index c293282d7..43f68a926 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -54,35 +54,16 @@ @CrossSync.export_sync( - path="google.cloud.bigtable.data.execute_query._sync.execute_query_iterator.ExecuteQueryIterator" + path="google.cloud.bigtable.data.execute_query._sync.execute_query_iterator.ExecuteQueryIterator", ) class ExecuteQueryIteratorAsync: - """ - ExecuteQueryIteratorAsync handles collecting streaming responses from the - ExecuteQuery RPC and parsing them to QueryResultRows. - - ExecuteQueryIteratorAsync implements Asynchronous Iterator interface and can - be used with "async for" syntax. It is also a context manager. - - It is **not thread-safe**. It should not be used by multiple asyncio Tasks. - - Args: - client: bigtable client - instance_id: id of the instance on which the query is executed - request_body: dict representing the body of the ExecuteQueryRequest - attempt_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to 600 seconds. - operation_timeout: the time budget for an individual network request, in seconds. - If it takes longer than this time to complete, the request will be cancelled with - a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the 20 seconds. If None, defaults to operation_timeout. - req_metadata: metadata used while sending the gRPC request - retryable_excs: a list of errors that will be retried if encountered. - Raises: - RuntimeError: if the instance is not created within an async event loop context. - """ + @CrossSync.convert( + docstring_format_vars={ + "NO_LOOP": ("RuntimeError: if the instance is not created within an async event loop context.", "None"), + "TASK_OR_THREAD": ("asyncio Tasks", "threads"), + } + ) def __init__( self, client: DataClientType, @@ -94,6 +75,27 @@ def __init__( req_metadata: Sequence[Tuple[str, str]], retryable_excs: List[type[Exception]], ) -> None: + """ + Collects responses from ExecuteQuery requests and parses them into QueryResultRows. + + It is **not thread-safe**. It should not be used by multiple {TASK_OR_THREAD}. + + Args: + client: bigtable client + instance_id: id of the instance on which the query is executed + request_body: dict representing the body of the ExecuteQueryRequest + attempt_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget. + Defaults to 600 seconds. + operation_timeout: the time budget for an individual network request, in seconds. + If it takes longer than this time to complete, the request will be cancelled with + a DeadlineExceeded exception, and a retry will be attempted. + Defaults to the 20 seconds. If None, defaults to operation_timeout. + req_metadata: metadata used while sending the gRPC request + retryable_excs: a list of errors that will be retried if encountered. 
+ Raises: + {NO_LOOP} + """ self._table_name = None self._app_profile_id = app_profile_id self._client = client From 1433136da785fc6a2ec09e89bbbf14fe4d46de1e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 3 Sep 2024 15:46:00 -0700 Subject: [PATCH 230/360] regenerated files --- google/cloud/bigtable/data/_sync/client.py | 72 +++++++++++-------- .../bigtable/data/_sync/mutations_batcher.py | 5 +- .../_sync/execute_query_iterator.py | 44 +++++------- tests/system/data/test_system.py | 1 + tests/unit/data/_sync/test__mutate_rows.py | 1 + tests/unit/data/_sync/test__read_rows.py | 3 + tests/unit/data/_sync/test_client.py | 33 ++++++--- .../unit/data/_sync/test_mutations_batcher.py | 2 + .../data/_sync/test_read_rows_acceptance.py | 6 ++ .../_sync/test_query_iterator.py | 2 + 10 files changed, 100 insertions(+), 69 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index 8274e922d..bee3d5c7e 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -84,19 +84,20 @@ @CrossSync._Sync_Impl.add_mapping_decorator("DataClient") class BigtableDataClient(ClientWithProject): + def __init__( self, *, project: str | None = None, pool_size: int = 3, credentials: google.auth.credentials.Credentials | None = None, - client_options: dict[str, Any] - | "google.api_core.client_options.ClientOptions" - | None = None, + client_options: ( + dict[str, Any] | "google.api_core.client_options.ClientOptions" | None + ) = None, ): """Create a client instance for the Bigtable Data API - Client should be created within an async context (running event loop) + Args: project: the project which the client acts on behalf of. @@ -113,8 +114,8 @@ def __init__( Client options used to set user options on the client. API Endpoint should be set through client_options. Raises: - RuntimeError: if called outside of an async context (no running event loop) - ValueError: if pool_size is less than 1""" + ValueError: if pool_size is less than 1 + """ transport_str = f"bt-{self._client_version()}-{pool_size}" transport = PooledTransportType.with_fixed_size(pool_size) BigtableClientMeta._transport_registry[transport_str] = transport @@ -187,7 +188,7 @@ def _start_background_channel_refresh(self) -> None: """Starts a background task to ping and warm each channel in the pool Raises: - RuntimeError: if not called in an asyncio event loop""" + None""" if ( not self._channel_refresh_tasks and (not self._emulator_host) @@ -260,7 +261,7 @@ def _manage_channel( refresh_interval_max: float = 60 * 45, grace_period: float = 60 * 10, ) -> None: - """Background coroutine that periodically refreshes and warms a grpc channel + """Background task that periodically refreshes and warms a grpc channel The backend will automatically close channels after 60 minutes, so `refresh_interval` + `grace_period` should be < 60 minutes @@ -361,6 +362,8 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> Table: """Returns a table instance for making data API requests. All arguments are passed directly to the Table constructor. + + Args: instance_id: The Bigtable instance ID to associate with this client. 
instance_id is combined with the client's project to fully @@ -393,8 +396,7 @@ def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> Table: Returns: Table: a table instance for making data API requests Raises: - RuntimeError: if called outside of an async context (no running event loop) - """ + None""" return Table(self, instance_id, table_id, *args, **kwargs) def execute_query( @@ -531,7 +533,7 @@ def __init__( ): """Initialize a Table instance - Must be created within an async context (running event loop) + Args: instance_id: The Bigtable instance ID to associate with this client. @@ -563,8 +565,7 @@ def __init__( encountered during all other operations. Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) Raises: - RuntimeError: if called outside of an async context (no running event loop) - """ + None""" _validate_timeouts( default_operation_timeout, default_attempt_timeout, allow_none=True ) @@ -621,8 +622,9 @@ def read_rows_stream( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: ( + Sequence[type[Exception]] | TABLE_DEFAULT + ) = TABLE_DEFAULT.READ_ROWS, ) -> Iterable[Row]: """Read a set of rows from the table, based on the specified query. Returns an iterator to asynchronously stream back row data. @@ -669,8 +671,9 @@ def read_rows( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: ( + Sequence[type[Exception]] | TABLE_DEFAULT + ) = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: """Read a set of rows from the table, based on the specified query. Retruns results as a list of Row objects when the request is complete. @@ -716,8 +719,9 @@ def read_row( row_filter: RowFilter | None = None, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: ( + Sequence[type[Exception]] | TABLE_DEFAULT + ) = TABLE_DEFAULT.READ_ROWS, ) -> Row | None: """Read a single row from the table, based on the specified key. @@ -763,8 +767,9 @@ def read_rows_sharded( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: ( + Sequence[type[Exception]] | TABLE_DEFAULT + ) = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: """Runs a sharded query in parallel, then return the results in a single list. Results will be returned in the order of the input queries. @@ -852,8 +857,9 @@ def row_exists( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: ( + Sequence[type[Exception]] | TABLE_DEFAULT + ) = TABLE_DEFAULT.READ_ROWS, ) -> bool: """Return a boolean indicating whether the specified row exists in the table. 
uses the filters: chain(limit cells per row = 1, strip value) @@ -897,8 +903,9 @@ def sample_row_keys( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + retryable_errors: ( + Sequence[type[Exception]] | TABLE_DEFAULT + ) = TABLE_DEFAULT.DEFAULT, ) -> RowKeySamples: """Return a set of RowKeySamples that delimit contiguous sections of the table of approximately equal size @@ -970,8 +977,9 @@ def mutations_batcher( flow_control_max_bytes: int = 100 * _MB_SIZE, batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - batch_retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: ( + Sequence[type[Exception]] | TABLE_DEFAULT + ) = TABLE_DEFAULT.MUTATE_ROWS, ) -> MutationsBatcher: """Returns a new mutations batcher instance. @@ -1015,8 +1023,9 @@ def mutate_row( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + retryable_errors: ( + Sequence[type[Exception]] | TABLE_DEFAULT + ) = TABLE_DEFAULT.DEFAULT, ): """Mutates a row atomically. @@ -1086,8 +1095,9 @@ def bulk_mutate_rows( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + retryable_errors: ( + Sequence[type[Exception]] | TABLE_DEFAULT + ) = TABLE_DEFAULT.MUTATE_ROWS, ): """Applies mutations for multiple rows in a single batched request. diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index e9264c316..dfd889d0e 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -194,8 +194,9 @@ def __init__( flow_control_max_bytes: int = 100 * _MB_SIZE, batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - batch_retryable_errors: Sequence[type[Exception]] - | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: ( + Sequence[type[Exception]] | TABLE_DEFAULT + ) = TABLE_DEFAULT.MUTATE_ROWS, ): self._operation_timeout, self._attempt_timeout = _get_timeouts( batch_operation_timeout, batch_attempt_timeout, table diff --git a/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py index 9afb59bef..691675c53 100644 --- a/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py @@ -41,31 +41,6 @@ class ExecuteQueryIterator: - """ - ExecuteQueryIteratorAsync handles collecting streaming responses from the - ExecuteQuery RPC and parsing them to QueryResultRows. - - ExecuteQueryIteratorAsync implements Asynchronous Iterator interface and can - be used with "async for" syntax. It is also a context manager. - - It is **not thread-safe**. It should not be used by multiple asyncio Tasks. 
-
-    Args:
-        client: bigtable client
-        instance_id: id of the instance on which the query is executed
-        request_body: dict representing the body of the ExecuteQueryRequest
-        attempt_timeout: the time budget for the entire operation, in seconds.
-            Failed requests will be retried within the budget.
-            Defaults to 600 seconds.
-        operation_timeout: the time budget for an individual network request, in seconds.
-            If it takes longer than this time to complete, the request will be cancelled with
-            a DeadlineExceeded exception, and a retry will be attempted.
-            Defaults to the 20 seconds. If None, defaults to operation_timeout.
-        req_metadata: metadata used while sending the gRPC request
-        retryable_excs: a list of errors that will be retried if encountered.
-    Raises:
-        RuntimeError: if the instance is not created within an async event loop context.
-    """
 
     def __init__(
         self,
@@ -78,6 +53,25 @@ def __init__(
         req_metadata: Sequence[Tuple[str, str]],
         retryable_excs: List[type[Exception]],
     ) -> None:
+        """Collects responses from ExecuteQuery requests and parses them into QueryResultRows.
+
+        It is **not thread-safe**. It should not be used by multiple threads.
+
+        Args:
+            client: bigtable client
+            instance_id: id of the instance on which the query is executed
+            request_body: dict representing the body of the ExecuteQueryRequest
+            attempt_timeout: the time budget for an individual network request, in seconds.
+                If it takes longer than this time to complete, the request will be cancelled with
+                a DeadlineExceeded exception, and a retry will be attempted.
+                Defaults to 20 seconds. If None, defaults to operation_timeout.
+            operation_timeout: the time budget for the entire operation, in seconds.
+                Failed requests will be retried within the budget.
+                Defaults to 600 seconds.
+            req_metadata: metadata used while sending the gRPC request
+            retryable_excs: a list of errors that will be retried if encountered.
+ Raises: + None""" self._table_name = None self._app_profile_id = app_profile_id self._client = client diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py index ac586ea47..32e24463b 100644 --- a/tests/system/data/test_system.py +++ b/tests/system/data/test_system.py @@ -70,6 +70,7 @@ def delete_rows(self): class TestSystem: + @pytest.fixture(scope="session") def client(self): project = os.getenv("GOOGLE_CLOUD_PROJECT") or None diff --git a/tests/unit/data/_sync/test__mutate_rows.py b/tests/unit/data/_sync/test__mutate_rows.py index d394ff954..73c714246 100644 --- a/tests/unit/data/_sync/test__mutate_rows.py +++ b/tests/unit/data/_sync/test__mutate_rows.py @@ -27,6 +27,7 @@ class TestMutateRowsOperation: + def _target_class(self): return CrossSync._Sync_Impl._MutateRowsOperation diff --git a/tests/unit/data/_sync/test__read_rows.py b/tests/unit/data/_sync/test__read_rows.py index 556bdf0ec..a71b1bf2b 100644 --- a/tests/unit/data/_sync/test__read_rows.py +++ b/tests/unit/data/_sync/test__read_rows.py @@ -237,6 +237,7 @@ def test_revise_limit(self, start_limit, emit_num, expected_limit): from google.cloud.bigtable_v2.types import ReadRowsResponse def awaitable_stream(): + def mock_stream(): for i in range(emit_num): yield ReadRowsResponse( @@ -272,6 +273,7 @@ def test_revise_limit_over_limit(self, start_limit, emit_num): from google.cloud.bigtable.data.exceptions import InvalidChunk def awaitable_stream(): + def mock_stream(): for i in range(emit_num): yield ReadRowsResponse( @@ -330,6 +332,7 @@ def test_retryable_ignore_repeated_rows(self): row_key = b"duplicate" def mock_awaitable_stream(): + def mock_stream(): while True: yield ReadRowsResponse( diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 7de47de5d..570786796 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -46,6 +46,7 @@ @CrossSync._Sync_Impl.add_mapping_decorator("TestBigtableDataClient") class TestBigtableDataClient: + @staticmethod def _get_target_class(): return CrossSync._Sync_Impl.DataClient @@ -316,11 +317,9 @@ def test__ping_and_warm_instances(self): gather.assert_awaited_once() grpc_call_args = channel.unary_unary().call_args_list for idx, (_, kwargs) in enumerate(grpc_call_args): - ( - expected_instance, - expected_table, - expected_app_profile, - ) = client_mock._active_instances[idx] + expected_instance, expected_table, expected_app_profile = ( + client_mock._active_instances[idx] + ) request = kwargs["request"] assert request["name"] == expected_instance assert request["app_profile_id"] == expected_app_profile @@ -409,9 +408,9 @@ def test__manage_channel_ping_and_warm(self): ) with mock.patch.object(*sleep_tuple): client_mock.transport.replace_channel.side_effect = asyncio.CancelledError - ping_and_warm = ( - client_mock._ping_and_warm_instances - ) = CrossSync._Sync_Impl.Mock() + ping_and_warm = client_mock._ping_and_warm_instances = ( + CrossSync._Sync_Impl.Mock() + ) try: channel_idx = 1 self._get_target_class()._manage_channel(client_mock, channel_idx, 10) @@ -962,6 +961,7 @@ def test_context_manager(self): @CrossSync._Sync_Impl.add_mapping_decorator("TestTable") class TestTable: + def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -1275,6 +1275,7 @@ def _make_gapic_stream( from google.cloud.bigtable_v2 import ReadRowsResponse class mock_stream: + def __init__(self, chunk_list, sleep_time): self.chunk_list = chunk_list 
self.idx = -1 @@ -1672,6 +1673,7 @@ def test_row_exists(self, return_value, expected_result): class TestReadRowsSharded: + def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -1857,6 +1859,7 @@ def mock_call(*args, **kwargs): class TestSampleRowKeys: + def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -1995,6 +1998,7 @@ def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): class TestMutateRow: + def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -2147,6 +2151,7 @@ def test_mutate_row_no_mutations(self, mutations): class TestBulkMutateRows: + def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -2483,6 +2488,7 @@ def test_bulk_mutate_error_recovery(self): class TestCheckAndMutateRow: + def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -2627,6 +2633,7 @@ def test_check_and_mutate_mutations_parsing(self): class TestReadModifyWriteRow: + def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -2752,7 +2759,9 @@ def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) def _make_gapic_stream(self, sample_list: list["ExecuteQueryResponse" | Exception]): + class MockStream: + def __init__(self, sample_list): self.sample_list = sample_list @@ -2807,9 +2816,11 @@ def resonse_with_result(self, *args, resume_token=None): else: pb_value = PBValue( { - "int_value" - if isinstance(column_value, int) - else "string_value": column_value + ( + "int_value" + if isinstance(column_value, int) + else "string_value" + ): column_value } ) values.append(pb_value) diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py index 49cc3efeb..fe3792293 100644 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -29,6 +29,7 @@ class Test_FlowControl: + @staticmethod def _target_class(): return CrossSync._Sync_Impl._FlowControl @@ -258,6 +259,7 @@ def test_add_to_flow_oversize(self): class TestMutationsBatcher: + def _get_target_class(self): return CrossSync._Sync_Impl.MutationsBatcher diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync/test_read_rows_acceptance.py index dcdd7d66c..ccb4f42e0 100644 --- a/tests/unit/data/_sync/test_read_rows_acceptance.py +++ b/tests/unit/data/_sync/test_read_rows_acceptance.py @@ -27,6 +27,7 @@ class TestReadRowsAcceptance: + @staticmethod def _get_operation_class(): return CrossSync._Sync_Impl._ReadRowsOperation @@ -64,6 +65,7 @@ def _coro_wrapper(stream): return stream def _process_chunks(self, *chunks): + def _row_stream(): yield ReadRowsResponse(chunks=chunks) @@ -83,6 +85,7 @@ def _row_stream(): "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description ) def test_row_merger_scenario(self, test_case: ReadRowsTest): + def _scenerio_stream(): for chunk in test_case.chunks: yield ReadRowsResponse(chunks=[chunk]) @@ -116,10 +119,12 @@ def _scenerio_stream(): "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description ) def test_read_rows_scenario(self, test_case: ReadRowsTest): + def 
_make_gapic_stream(chunk_list: list[ReadRowsResponse]): from google.cloud.bigtable_v2 import ReadRowsResponse class mock_stream: + def __init__(self, chunk_list): self.chunk_list = chunk_list self.idx = -1 @@ -175,6 +180,7 @@ def cancel(self): assert actual == expected def test_out_of_order_rows(self): + def _row_stream(): yield ReadRowsResponse(last_scanned_row_key=b"a") diff --git a/tests/unit/data/execute_query/_sync/test_query_iterator.py b/tests/unit/data/execute_query/_sync/test_query_iterator.py index b3eded81e..8e52a1d76 100644 --- a/tests/unit/data/execute_query/_sync/test_query_iterator.py +++ b/tests/unit/data/execute_query/_sync/test_query_iterator.py @@ -26,6 +26,7 @@ class MockIterator: + def __init__(self, values, delay=None): self._values = values self.idx = 0 @@ -45,6 +46,7 @@ def __next__(self): class TestQueryIterator: + @staticmethod def _target_class(): return CrossSync._Sync_Impl.ExecuteQueryIterator From 2dce3b94f42d215626d0eadd4963b303c177afbe Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 3 Sep 2024 16:50:13 -0700 Subject: [PATCH 231/360] added test file for cross sync --- tests/unit/data/_sync/test_cross_sync.py | 42 ++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/unit/data/_sync/test_cross_sync.py diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py new file mode 100644 index 000000000..66bc3a37d --- /dev/null +++ b/tests/unit/data/_sync/test_cross_sync.py @@ -0,0 +1,42 @@ +import typing +import asyncio +import pytest +import threading +import concurrent.futures +import time +import queue +from google import api_core +from google.cloud.bigtable.data._sync.cross_sync.cross_sync import CrossSync, T + + +class TestCrossSync: + + @pytest.mark.parametrize( + "attr, async_version, sync_version", [ + ("is_async", True, False), + ("sleep", asyncio.sleep, time.sleep), + ("retry_target", api_core.retry.retry_target_async, api_core.retry.retry_target), + ("retry_target_stream", api_core.retry.retry_target_stream_async, api_core.retry.retry_target_stream), + ("Retry", api_core.retry.AsyncRetry, api_core.retry.Retry), + ("Queue", asyncio.Queue, queue.Queue), + ("Condition", asyncio.Condition, threading.Condition), + ("Future", asyncio.Future, concurrent.futures.Future), + ("Task", asyncio.Task, concurrent.futures.Future), + ("Event", asyncio.Event, threading.Event), + ("Semaphore", asyncio.Semaphore, threading.Semaphore), + ("StopIteration", StopAsyncIteration, StopIteration), + # types + ("Awaitable", typing.Awaitable, typing.Union[T]), + ("Iterable", typing.AsyncIterable, typing.Iterable), + ("Iterator", typing.AsyncIterator, typing.Iterator), + ("Generator", typing.AsyncGenerator, typing.Generator), + ] + ) + def test_alias_attributes(self, attr, async_version, sync_version): + """ + Test basic alias attributes, to ensure they point to the right place + in both sync and async versions. 
+ """ + assert getattr(CrossSync, attr) == async_version, f"Failed async version for {attr}" + assert getattr(CrossSync._Sync_Impl, attr) == sync_version, f"Failed sync version for {attr}" + From ade18b19ead2dda7b6eb9da5d7506dcbaaa5abf3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 3 Sep 2024 17:33:47 -0700 Subject: [PATCH 232/360] added some tests --- tests/unit/data/_sync/test_cross_sync.py | 160 ++++++++++++++++++++++- 1 file changed, 156 insertions(+), 4 deletions(-) diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py index 66bc3a37d..3f50b21f6 100644 --- a/tests/unit/data/_sync/test_cross_sync.py +++ b/tests/unit/data/_sync/test_cross_sync.py @@ -1,16 +1,31 @@ import typing import asyncio import pytest +import pytest_asyncio import threading import concurrent.futures import time import queue +import functools from google import api_core from google.cloud.bigtable.data._sync.cross_sync.cross_sync import CrossSync, T - +from unittest import mock class TestCrossSync: + async def async_iter(self, in_list): + for i in in_list: + yield i + + @pytest.fixture + def cs_sync(self): + return CrossSync._Sync_Impl + + @pytest_asyncio.fixture + def cs_async(self): + return CrossSync + + @pytest.mark.parametrize( "attr, async_version, sync_version", [ ("is_async", True, False), @@ -32,11 +47,148 @@ class TestCrossSync: ("Generator", typing.AsyncGenerator, typing.Generator), ] ) - def test_alias_attributes(self, attr, async_version, sync_version): + def test_alias_attributes(self, attr, async_version, sync_version, cs_sync, cs_async): """ Test basic alias attributes, to ensure they point to the right place in both sync and async versions. """ - assert getattr(CrossSync, attr) == async_version, f"Failed async version for {attr}" - assert getattr(CrossSync._Sync_Impl, attr) == sync_version, f"Failed sync version for {attr}" + assert getattr(cs_async, attr) == async_version, f"Failed async version for {attr}" + assert getattr(cs_sync, attr) == sync_version, f"Failed sync version for {attr}" + + @pytest.mark.asyncio + async def test_Mock(self, cs_sync, cs_async): + """ + Test Mock class in both sync and async versions + """ + assert isinstance(cs_async.Mock(), mock.AsyncMock) + assert isinstance(cs_sync.Mock(), mock.Mock) + # test with return value + assert await cs_async.Mock(return_value=1)() == 1 + assert cs_sync.Mock(return_value=1)() == 1 + + def test_next(self, cs_sync): + """ + Test sync version of CrossSync.next() + """ + it = iter([1, 2, 3]) + assert cs_sync.next(it) == 1 + assert cs_sync.next(it) == 2 + assert cs_sync.next(it) == 3 + with pytest.raises(StopIteration): + cs_sync.next(it) + with pytest.raises(cs_sync.StopIteration): + cs_sync.next(it) + + @pytest.mark.asyncio + async def test_next_async(self, cs_async): + """ + test async version of CrossSync.next() + """ + async_it = self.async_iter([1, 2, 3]) + assert await cs_async.next(async_it) == 1 + assert await cs_async.next(async_it) == 2 + assert await cs_async.next(async_it) == 3 + with pytest.raises(StopAsyncIteration): + await cs_async.next(async_it) + with pytest.raises(cs_async.StopIteration): + await cs_async.next(async_it) + + def test_gather_partials(self, cs_sync): + """ + Test sync version of CrossSync.gather_partials() + """ + with concurrent.futures.ThreadPoolExecutor() as e: + partials = [lambda i=i: i + 1 for i in range(5)] + results = cs_sync.gather_partials(partials, sync_executor=e) + assert results == [1, 2, 3, 4, 5] + + def 
test_gather_partials_with_exceptions(self, cs_sync):
+        """
+        Test sync version of CrossSync.gather_partials() with exceptions
+        """
+        with concurrent.futures.ThreadPoolExecutor() as e:
+            partials = [lambda i=i: i + 1 if i != 3 else 1/0 for i in range(5)]
+            with pytest.raises(ZeroDivisionError):
+                cs_sync.gather_partials(partials, sync_executor=e)
+
+    def test_gather_partials_return_exceptions(self, cs_sync):
+        """
+        Test sync version of CrossSync.gather_partials() with return_exceptions=True
+        """
+        with concurrent.futures.ThreadPoolExecutor() as e:
+            partials = [lambda i=i: i + 1 if i != 3 else 1/0 for i in range(5)]
+            results = cs_sync.gather_partials(partials, return_exceptions=True, sync_executor=e)
+            assert len(results) == 5
+            assert results[0] == 1
+            assert results[1] == 2
+            assert results[2] == 3
+            assert isinstance(results[3], ZeroDivisionError)
+            assert results[4] == 5
+
+    def test_gather_partials_no_executor(self, cs_sync):
+        """
+        Test sync version of CrossSync.gather_partials() without an executor
+        """
+        partials = [lambda i=i: i + 1 for i in range(5)]
+        with pytest.raises(ValueError) as e:
+            results = cs_sync.gather_partials(partials)
+        assert "sync_executor is required" in str(e.value)
+
+    @pytest.mark.asyncio
+    async def test_gather_partials_async(self, cs_async):
+        """
+        Test async version of CrossSync.gather_partials()
+        """
+        async def coro(i):
+            return i + 1
+
+        partials = [functools.partial(coro, i) for i in range(5)]
+        results = await cs_async.gather_partials(partials)
+        assert results == [1, 2, 3, 4, 5]
+
+    @pytest.mark.asyncio
+    async def test_gather_partials_async_with_exceptions(self, cs_async):
+        """
+        Test async version of CrossSync.gather_partials() with exceptions
+        """
+        async def coro(i):
+            return i + 1 if i != 3 else 1/0
+
+        partials = [functools.partial(coro, i) for i in range(5)]
+        with pytest.raises(ZeroDivisionError):
+            await cs_async.gather_partials(partials)
+
+    @pytest.mark.asyncio
+    async def test_gather_partials_async_return_exceptions(self, cs_async):
+        """
+        Test async version of CrossSync.gather_partials() with return_exceptions=True
+        """
+        async def coro(i):
+            return i + 1 if i != 3 else 1/0
+
+        partials = [functools.partial(coro, i) for i in range(5)]
+        results = await cs_async.gather_partials(partials, return_exceptions=True)
+        assert len(results) == 5
+        assert results[0] == 1
+        assert results[1] == 2
+        assert results[2] == 3
+        assert isinstance(results[3], ZeroDivisionError)
+        assert results[4] == 5
+
+    @pytest.mark.asyncio
+    async def test_gather_partials_async_uses_asyncio_gather(self, cs_async):
+        """
+        CrossSync.gather_partials() should use asyncio.gather() internally
+        """
+        async def coro(i):
+            return i + 1
+        return_exceptions=object()
+        partials = [functools.partial(coro, i) for i in range(5)]
+        with mock.patch.object(asyncio, "gather", mock.AsyncMock()) as gather:
+            await cs_async.gather_partials(partials, return_exceptions=return_exceptions)
+            gather.assert_called_once()
+            found_args, found_kwargs = gather.call_args
+            assert found_kwargs["return_exceptions"] == return_exceptions
+            for coro in found_args:
+                await coro

From 933a6263f83ab7cb63d8027eebb076861136433d Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 5 Sep 2024 10:46:41 -0700
Subject: [PATCH 233/360] moved wait into aliases

---
 .../data/_sync/cross_sync/cross_sync.py   | 30 ++-----------------
 tests/unit/data/_sync/test_cross_sync.py  |  1 +
 2 files changed, 3 insertions(+), 28 deletions(-)

diff --git a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py
b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py index 1219c3b3d..80d5973d4 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py @@ -80,6 +80,7 @@ class CrossSync(metaclass=MappingMeta): # provide aliases for common async functions and types sleep = asyncio.sleep + wait = asyncio.wait retry_target = retries.retry_target_async retry_target_stream = retries.retry_target_stream_async Retry = retries.AsyncRetry @@ -146,20 +147,6 @@ async def gather_partials( *awaitable_list, return_exceptions=return_exceptions ) - @staticmethod - async def wait( - futures: Sequence[CrossSync.Future[T]], timeout: float | None = None - ) -> tuple[set[CrossSync.Future[T]], set[CrossSync.Future[T]]]: - """ - abstraction over asyncio.wait - - Return: - - a tuple of (done, pending) sets of futures - """ - if not futures: - return set(), set() - return await asyncio.wait(futures, timeout=timeout) - @staticmethod async def condition_wait( condition: CrossSync.Condition, timeout: float | None = None @@ -251,6 +238,7 @@ class _Sync_Impl(metaclass=MappingMeta): is_async = False sleep = time.sleep + wait = concurrent.futures.wait next = next retry_target = retries.retry_target retry_target_stream = retries.retry_target_stream @@ -277,20 +265,6 @@ def Mock(cls, *args, **kwargs): from mock import Mock # type: ignore return Mock(*args, **kwargs) - @staticmethod - def wait( - futures: Sequence[CrossSync._Sync_Impl.Future[T]], - timeout: float | None = None, - ) -> tuple[ - set[CrossSync._Sync_Impl.Future[T]], set[CrossSync._Sync_Impl.Future[T]] - ]: - """ - abstraction over asyncio.wait - """ - if not futures: - return set(), set() - return concurrent.futures.wait(futures, timeout=timeout) - @staticmethod def condition_wait( condition: CrossSync._Sync_Impl.Condition, timeout: float | None = None diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py index 3f50b21f6..b8cd92f9a 100644 --- a/tests/unit/data/_sync/test_cross_sync.py +++ b/tests/unit/data/_sync/test_cross_sync.py @@ -30,6 +30,7 @@ def cs_async(self): "attr, async_version, sync_version", [ ("is_async", True, False), ("sleep", asyncio.sleep, time.sleep), + ("wait", asyncio.wait, concurrent.futures.wait), ("retry_target", api_core.retry.retry_target_async, api_core.retry.retry_target), ("retry_target_stream", api_core.retry.retry_target_stream_async, api_core.retry.retry_target_stream), ("Retry", api_core.retry.AsyncRetry, api_core.retry.Retry), From 8efe71d8929c7d54cf9e125163a212fe77f98a78 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 5 Sep 2024 11:27:25 -0700 Subject: [PATCH 234/360] stripped out condition_wait --- .../data/_sync/cross_sync/cross_sync.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py index 80d5973d4..80bc523eb 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py @@ -147,21 +147,6 @@ async def gather_partials( *awaitable_list, return_exceptions=return_exceptions ) - @staticmethod - async def condition_wait( - condition: CrossSync.Condition, timeout: float | None = None - ) -> bool: - """ - abstraction over asyncio.Condition.wait - - returns False if the timeout is reached before the condition is set, otherwise True - """ - try: - await asyncio.wait_for(condition.wait(), timeout=timeout) - 
return True - except asyncio.TimeoutError: - return False - @staticmethod async def event_wait( event: CrossSync.Event, @@ -265,15 +250,6 @@ def Mock(cls, *args, **kwargs): from mock import Mock # type: ignore return Mock(*args, **kwargs) - @staticmethod - def condition_wait( - condition: CrossSync._Sync_Impl.Condition, timeout: float | None = None - ) -> bool: - """ - returns False if the timeout is reached before the condition is set, otherwise True - """ - return condition.wait(timeout=timeout) - @staticmethod def event_wait( event: CrossSync._Sync_Impl.Event, From 09623f45f8387aae99b60566817d553b42d7bbce Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 5 Sep 2024 12:20:12 -0700 Subject: [PATCH 235/360] added tests for event_wait --- .../data/_sync/cross_sync/cross_sync.py | 3 +- tests/unit/data/_sync/test_cross_sync.py | 96 +++++++++++++++++++ 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py index 80bc523eb..7fb2a794b 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py @@ -166,7 +166,8 @@ async def event_wait( if timeout is None: await event.wait() elif not async_break_early: - await asyncio.sleep(timeout) + if not event.is_set(): + await asyncio.sleep(timeout) else: try: await asyncio.wait_for(event.wait(), timeout=timeout) diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py index b8cd92f9a..90bb1842f 100644 --- a/tests/unit/data/_sync/test_cross_sync.py +++ b/tests/unit/data/_sync/test_cross_sync.py @@ -193,3 +193,99 @@ async def coro(i): assert found_kwargs["return_exceptions"] == return_exceptions for coro in found_args: await coro + + def test_event_wait_passthrough(self, cs_sync): + """ + Test sync version of CrossSync.event_wait() + should pass through timeout directly to the event.wait() call + """ + event = mock.Mock() + timeout = object() + cs_sync.event_wait(event, timeout) + event.wait.assert_called_once_with(timeout=timeout) + + @pytest.mark.parametrize("timeout", [0, 0.01, 0.05]) + def test_event_wait_timeout_exceeded(self, cs_sync, timeout): + """ + Test sync version of CrossSync.event_wait() + """ + event = threading.Event() + start_time = time.monotonic() + cs_sync.event_wait(event, timeout=timeout) + end_time = time.monotonic() + assert abs((end_time - start_time) - timeout) < 0.01 + + def test_event_wait_already_set(self, cs_sync): + """ + if event is already set, do not block + """ + event = threading.Event() + event.set() + start_time = time.monotonic() + cs_sync.event_wait(event, timeout=10) + end_time = time.monotonic() + assert end_time - start_time < 0.01 + + @pytest.mark.parametrize("break_early", [True, False]) + @pytest.mark.asyncio + async def test_event_wait_async(self, cs_async, break_early): + """ + With no timeout, call event.wait() with no arguments + """ + event = mock.AsyncMock() + await cs_async.event_wait(event, async_break_early=break_early) + event.wait.assert_called_once_with() + + + @pytest.mark.asyncio + async def test_event_wait_async_with_timeout(self, cs_async): + """ + In with timeout set, should call event.wait(), wrapped in wait_for() + for the timeout + """ + event = mock.Mock() + event.wait.return_value = object() + timeout = object() + with mock.patch.object(asyncio, "wait_for", mock.AsyncMock()) as wait_for: + await cs_async.event_wait(event, timeout=timeout) + assert 
wait_for.await_count == 1
+            assert wait_for.call_count == 1
+            wait_for.assert_called_once_with(event.wait(), timeout=timeout)
+
+    @pytest.mark.asyncio
+    async def test_event_wait_async_timeout_exceeded(self, cs_async):
+        """
+        If timeout is exceeded, break without throwing an exception
+        """
+        event = asyncio.Event()
+        timeout = 0.5
+        start_time = time.monotonic()
+        await cs_async.event_wait(event, timeout=timeout)
+        end_time = time.monotonic()
+        assert abs((end_time - start_time) - timeout) < 0.01
+
+    @pytest.mark.parametrize("break_early", [True, False])
+    @pytest.mark.asyncio
+    async def test_event_wait_async_already_set(self, cs_async, break_early):
+        """
+        if event is already set, return immediately
+        """
+        event = mock.AsyncMock()
+        event.is_set = lambda: True
+        start_time = time.monotonic()
+        await cs_async.event_wait(event, async_break_early=break_early)
+        end_time = time.monotonic()
+        assert abs(end_time - start_time) < 0.01
+
+    @pytest.mark.asyncio
+    async def test_event_wait_no_break_early(self, cs_async):
+        """
+        if async_break_early is False, and the event is not set,
+        simply sleep for the timeout
+        """
+        event = mock.Mock()
+        event.is_set.return_value = False
+        timeout = object()
+        with mock.patch.object(asyncio, "sleep", mock.AsyncMock()) as sleep:
+            await cs_async.event_wait(event, timeout=timeout, async_break_early=False)
+            sleep.assert_called_once_with(timeout)

From 6fcbc89b98934fb694fab652ab58a93f1fb4b693 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 5 Sep 2024 12:35:58 -0700
Subject: [PATCH 236/360] added tests for create_task

---
 tests/unit/data/_sync/test_cross_sync.py | 71 ++++++++++++++++++++++++
 1 file changed, 71 insertions(+)

diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py
index 90bb1842f..c6e95d6b4 100644
--- a/tests/unit/data/_sync/test_cross_sync.py
+++ b/tests/unit/data/_sync/test_cross_sync.py
@@ -289,3 +289,74 @@ async def test_event_wait_no_break_early(self, cs_async):
         with mock.patch.object(asyncio, "sleep", mock.AsyncMock()) as sleep:
             await cs_async.event_wait(event, timeout=timeout, async_break_early=False)
             sleep.assert_called_once_with(timeout)
+
+    def test_create_task(self, cs_sync):
+        """
+        Test creating Future using create_task()
+        """
+        executor = concurrent.futures.ThreadPoolExecutor()
+        fn = lambda x, y: x + y
+        result = cs_sync.create_task(fn, 1, y=4, sync_executor=executor)
+        assert isinstance(result, cs_sync.Task)
+        assert result.result() == 5
+
+    def test_create_task_passthrough(self, cs_sync):
+        """
+        sync version passed through to executor.submit()
+        """
+        fn = object()
+        executor = mock.Mock()
+        executor.submit.return_value = object()
+        args = [1, 2, 3]
+        kwargs = {"a": 1, "b": 2}
+        result = cs_sync.create_task(fn, *args, **kwargs, sync_executor=executor)
+        assert result == executor.submit.return_value
+        assert executor.submit.call_count == 1
+        assert executor.submit.call_args == ((fn, *args), kwargs)
+
+
+    def test_create_task_no_executor(self, cs_sync):
+        """
+        if no executor is provided, raise an exception
+        """
+        with pytest.raises(ValueError) as e:
+            cs_sync.create_task(lambda: None)
+        assert "sync_executor is required" in str(e.value)
+
+    @pytest.mark.asyncio
+    async def test_create_task_async(self, cs_async):
+        """
+        Test creating Future using create_task()
+        """
+        async def coro_fn(x, y):
+            return x + y
+        result = cs_async.create_task(coro_fn, 1, y=4)
+        assert isinstance(result, asyncio.Task)
+        assert await result == 5
+
+    @pytest.mark.asyncio
+    async def test_create_task_async_passthrough(self, cs_async):
+        """
+        async version passed through to asyncio.create_task()
+        """
+        coro_fn = mock.Mock()
+        coro_fn.return_value = object()
+        args = [1, 2, 3]
+        kwargs = {"a": 1, "b": 2}
+        with mock.patch.object(asyncio, "create_task", mock.Mock()) as create_task:
+            result = cs_async.create_task(coro_fn, *args, **kwargs)
+            create_task.assert_called_once()
+            create_task.assert_called_once_with(coro_fn.return_value)
+            coro_fn.assert_called_once_with(*args, **kwargs)
+
+    @pytest.mark.asyncio
+    async def test_create_task_async_with_name(self, cs_async):
+        """
+        Test creating a task with a name
+        """
+        async def coro_fn():
+            return None
+        name = "test-name-456"
+        result = cs_async.create_task(coro_fn, task_name=name)
+        assert isinstance(result, asyncio.Task)
+        assert result.get_name() == name

From 9044c4a187869e6944b906cec60ed5edcbc57931 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 5 Sep 2024 13:07:50 -0700
Subject: [PATCH 237/360] added tests for remaining functions

---
 tests/unit/data/_sync/test_cross_sync.py | 47 ++++++++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py
index c6e95d6b4..40bbaa2c8 100644
--- a/tests/unit/data/_sync/test_cross_sync.py
+++ b/tests/unit/data/_sync/test_cross_sync.py
@@ -360,3 +360,50 @@ async def coro_fn():
         result = cs_async.create_task(coro_fn, task_name=name)
         assert isinstance(result, asyncio.Task)
         assert result.get_name() == name
+
+    def test_yield_to_event_loop(self, cs_sync):
+        """
+        no-op in sync version
+        """
+        assert cs_sync.yield_to_event_loop() is None
+
+    @pytest.mark.asyncio
+    async def test_yield_to_event_loop_async(self, cs_async):
+        """
+        should call await asyncio.sleep(0)
+        """
+        with mock.patch.object(asyncio, "sleep", mock.AsyncMock()) as sleep:
+            await cs_async.yield_to_event_loop()
+            sleep.assert_called_once_with(0)
+
+    def test_verify_async_event_loop(self, cs_sync):
+        """
+        no-op in sync version
+        """
+        assert cs_sync.verify_async_event_loop() is None
+
+    @pytest.mark.asyncio
+    async def test_verify_async_event_loop_async(self, cs_async):
+        """
+        should call asyncio.get_running_loop()
+        """
+        with mock.patch.object(asyncio, "get_running_loop") as get_running_loop:
+            cs_async.verify_async_event_loop()
+            get_running_loop.assert_called_once()
+
+    def test_verify_async_event_loop_no_event_loop(self, cs_async):
+        """
+        Should raise an exception if no event loop is running
+        """
+        with pytest.raises(RuntimeError) as e:
+            cs_async.verify_async_event_loop()
+        assert "no running event loop" in str(e.value)
+
+    def test_rmaio(self, cs_async):
+        """
+        rm_aio should return whatever is passed to it
+        """
+        assert cs_async.rm_aio(1) == 1
+        assert cs_async.rm_aio("test") == "test"
+        obj = object()
+        assert cs_async.rm_aio(obj) == obj

From ee11e049f79a318e809aa30cf1364227c5bff381 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 5 Sep 2024 13:53:38 -0700
Subject: [PATCH 238/360] added test outline for decorators

---
 .../data/_sync/test_cross_sync_decorators.py | 125 ++++++++++++++++++
 1 file changed, 125 insertions(+)
 create mode 100644 tests/unit/data/_sync/test_cross_sync_decorators.py

diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_sync/test_cross_sync_decorators.py
new file mode 100644
index 000000000..f210f97c8
--- /dev/null
+++ b/tests/unit/data/_sync/test_cross_sync_decorators.py
@@ -0,0 +1,125 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache
License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest import mock +from google.cloud.bigtable.data._sync.cross_sync.cross_sync import CrossSync +from google.cloud.bigtable.data._sync.cross_sync._decorators import ExportSync, Convert, DropMethod, Pytest, PytestFixture + + +class TestExportSyncDecorator: + + def _get_class(self): + return ExportSync + + def test_class_decorator(self): + """ + Should return class being decorated + """ + unwrapped_class = mock.Mock + wrapped_class = self._get_class().decorator(unwrapped_class, path=1) + assert unwrapped_class == wrapped_class + + def test_class_decorator_adds_mapping(self): + """ + If add_mapping_for_name is set, should call CrossSync.add_mapping with the class being decorated + """ + with mock.patch.object(CrossSync, "add_mapping") as add_mapping: + mock_cls = mock.Mock + # check decoration with no add_mapping + self._get_class().decorator(path=1)(mock_cls) + assert add_mapping.call_count == 0 + # check decoration with add_mapping + name = "test_name" + self._get_class().decorator(path=1, add_mapping_for_name=name)(mock_cls) + assert add_mapping.call_count == 1 + add_mapping.assert_called_once_with(name, mock_cls) + + @pytest.mark.parametrize("docstring,format_vars,expected", [ + ["test docstring", {}, "test docstring"], + ["{}", {}, "{}"], + ["test_docstring", {"A": (1, 2)}, "test_docstring"], + ["{A}", {"A": (1, 2)}, "1"], + ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "1 3"], + ["hello {world_var}", {"world_var": ("world", "moon")}, "hello world"], + ]) + def test_class_decorator_docstring_update(self, docstring, format_vars, expected): + """ + If docstring_format_vars is set, should update the docstring + of the class being decorated + """ + @self._get_class().decorator(path=1, docstring_format_vars=format_vars) + class Class: + __doc__ = docstring + assert Class.__doc__ == expected + # check internal state + instance = self._get_class()(path=1, docstring_format_vars=format_vars) + async_replacements = {k: v[0] for k, v in format_vars.items()} + sync_replacements = {k: v[1] for k, v in format_vars.items()} + assert instance.async_docstring_format_vars == async_replacements + assert instance.sync_docstring_format_vars == sync_replacements + + def test_sync_ast_transform(self): + pass + +class TestConvertDecorator: + + def _get_class(self): + return Convert + + def test_decorator_functionality(self): + pass + + def test_sync_ast_transform(self): + pass + +class TestDropMethodDecorator: + + def _get_class(self): + return DropMethod + + def test_decorator_functionality(self): + """ + applying the decorator should be a no-op + """ + unwrapped = lambda x: x + wrapped = self._get_class().decorator(unwrapped) + assert unwrapped == wrapped + assert unwrapped(1) == wrapped(1) + assert wrapped(1) == 1 + + def test_sync_ast_transform(self): + pass + +class TestPytestDecorator: + + def _get_class(self): + return Pytest + + def test_decorator_functionality(self): + pass + + def test_sync_ast_transform(self): + pass + +class TestPytestFixtureDecorator: + + def 
_get_class(self): + return PytestFixture + + def test_decorator_functionality(self): + pass + + def test_sync_ast_transform(self): + pass From 89c2abeb66c62ca86d3393178c67e5a196d481c9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 5 Sep 2024 14:37:49 -0700 Subject: [PATCH 239/360] added tests for ExportSync --- .../data/_sync/cross_sync/_decorators.py | 3 + .../data/_sync/test_cross_sync_decorators.py | 160 +++++++++++++++++- 2 files changed, 161 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index bf4d855de..02d9d6158 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -240,6 +240,8 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): wrapped_node.decorator_list = [ d for d in wrapped_node.decorator_list if "CrossSync" not in ast.dump(d) ] + else: + wrapped_node.decorator_list = [] # add mapping decorator if needed if self.add_mapping_for_name: wrapped_node.decorator_list.append( @@ -266,6 +268,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): wrapped_node = transformers_globals["CrossSyncMethodDecoratorHandler"]().visit( wrapped_node ) + # update docstring if specified if self.sync_docstring_format_vars: docstring = ast.get_docstring(wrapped_node) wrapped_node.body[0].value.s = docstring.format(**self.sync_docstring_format_vars) diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_sync/test_cross_sync_decorators.py index f210f97c8..f9aff87d8 100644 --- a/tests/unit/data/_sync/test_cross_sync_decorators.py +++ b/tests/unit/data/_sync/test_cross_sync_decorators.py @@ -13,6 +13,7 @@ # limitations under the License. 
import pytest +import ast from unittest import mock from google.cloud.bigtable.data._sync.cross_sync.cross_sync import CrossSync from google.cloud.bigtable.data._sync.cross_sync._decorators import ExportSync, Convert, DropMethod, Pytest, PytestFixture @@ -23,6 +24,54 @@ class TestExportSyncDecorator: def _get_class(self): return ExportSync + @pytest.fixture + def globals_mock(self): + mock_transform = mock.Mock() + mock_transform().visit = lambda x: x + global_dict = {k: mock_transform for k in ["RmAioFunctions", "SymbolReplacer", "CrossSyncMethodDecoratorHandler"]} + return global_dict + + def test_ctor_defaults(self): + """ + Should set default values for path, add_mapping_for_name, and docstring_format_vars + """ + with pytest.raises(TypeError) as exc: + self._get_class()() + assert "missing 1 required positional argument" in str(exc.value) + path = object() + instance = self._get_class()(path) + assert instance.path is path + assert instance.replace_symbols is None + assert instance.mypy_ignore is () + assert instance.include_file_imports is True + assert instance.add_mapping_for_name is None + assert instance.async_docstring_format_vars == {} + assert instance.sync_docstring_format_vars == {} + + def test_ctor(self): + path = object() + replace_symbols = {"a": "b"} + docstring_format_vars = {"A": (1, 2)} + mypy_ignore = ("a", "b") + include_file_imports = False + add_mapping_for_name = "test_name" + + instance = self._get_class()( + path=path, + replace_symbols=replace_symbols, + docstring_format_vars=docstring_format_vars, + mypy_ignore=mypy_ignore, + include_file_imports=include_file_imports, + add_mapping_for_name=add_mapping_for_name + ) + assert instance.path is path + assert instance.replace_symbols is replace_symbols + assert instance.mypy_ignore is mypy_ignore + assert instance.include_file_imports is include_file_imports + assert instance.add_mapping_for_name is add_mapping_for_name + assert instance.async_docstring_format_vars == {"A": 1} + assert instance.sync_docstring_format_vars == {"A": 2} + def test_class_decorator(self): """ Should return class being decorated @@ -70,8 +119,115 @@ class Class: assert instance.async_docstring_format_vars == async_replacements assert instance.sync_docstring_format_vars == sync_replacements - def test_sync_ast_transform(self): - pass + def test_sync_ast_transform_replaces_name(self, globals_mock): + """ + Should update the name of the new class + """ + decorator = self._get_class()("path.to.SyncClass") + mock_node = ast.ClassDef(name="AsyncClass", bases=[], keywords=[], body=[]) + + result = decorator.sync_ast_transform(mock_node, globals_mock) + + assert isinstance(result, ast.ClassDef) + assert result.name == "SyncClass" + + def test_sync_ast_transform_strips_cross_sync_decorators(self, globals_mock): + """ + should remove all CrossSync decorators from the class + """ + decorator = self._get_class()("path") + cross_sync_decorator = ast.Call(func=ast.Attribute(value=ast.Name(id='CrossSync', ctx=ast.Load()), attr='some_decorator', ctx=ast.Load()), args=[], keywords=[]) + other_decorator = ast.Name(id='other_decorator', ctx=ast.Load()) + mock_node = ast.ClassDef(name="AsyncClass", bases=[], keywords=[], body=[], decorator_list=[cross_sync_decorator, other_decorator]) + + result = decorator.sync_ast_transform(mock_node, globals_mock) + + assert isinstance(result, ast.ClassDef) + assert len(result.decorator_list) == 1 + assert isinstance(result.decorator_list[0], ast.Name) + assert result.decorator_list[0].id == 'other_decorator' + + def 
test_sync_ast_transform_add_mapping(self, globals_mock): + """ + If add_mapping_for_name is set, should add CrossSync.add_mapping_decorator to new class + """ + decorator = self._get_class()("path", add_mapping_for_name="sync_class") + mock_node = ast.ClassDef(name="AsyncClass", bases=[], keywords=[], body=[]) + + result = decorator.sync_ast_transform(mock_node, globals_mock) + + assert isinstance(result, ast.ClassDef) + assert len(result.decorator_list) == 1 + assert isinstance(result.decorator_list[0], ast.Call) + assert isinstance(result.decorator_list[0].func, ast.Attribute) + assert result.decorator_list[0].func.attr == 'add_mapping_decorator' + assert result.decorator_list[0].args[0].value == 'sync_class' + + @pytest.mark.parametrize("docstring,format_vars,expected", [ + ["test docstring", {}, "test docstring"], + ["{}", {}, "{}"], + ["test_docstring", {"A": (1, 2)}, "test_docstring"], + ["{A}", {"A": (1, 2)}, "2"], + ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "2 4"], + ["hello {world_var}", {"world_var": ("world", "moon")}, "hello moon"], + ]) + def test_sync_ast_transform_add_docstring_format(self, docstring, format_vars, expected, globals_mock): + """ + If docstring_format_vars is set, should format the docstring of the new class + """ + decorator = self._get_class()("path.to.SyncClass", docstring_format_vars=format_vars) + mock_node = ast.ClassDef( + name="AsyncClass", + bases=[], + keywords=[], + body=[ast.Expr(value=ast.Constant(value=docstring))] + ) + result = decorator.sync_ast_transform(mock_node, globals_mock) + + assert isinstance(result, ast.ClassDef) + assert isinstance(result.body[0], ast.Expr) + assert isinstance(result.body[0].value, ast.Constant) + assert result.body[0].value.value == expected + + def test_sync_ast_transform_call_cross_sync_transforms(self): + """ + Should use transformers_globals to call some extra transforms on class: + - RmAioFunctions + - SymbolReplacer + - CrossSyncMethodDecoratorHandler + """ + decorator = self._get_class()("path.to.SyncClass") + mock_node = ast.ClassDef(name="AsyncClass", bases=[], keywords=[], body=[]) + + transformers_globals = { + "RmAioFunctions": mock.Mock(), + "SymbolReplacer": mock.Mock(), + "CrossSyncMethodDecoratorHandler": mock.Mock(), + } + decorator.sync_ast_transform(mock_node, transformers_globals) + # ensure each transformer was called + for transformer in transformers_globals.values(): + assert transformer.call_count == 1 + + def test_sync_ast_transform_replace_symbols(self, globals_mock): + """ + SymbolReplacer should be called with replace_symbols + """ + replace_symbols = {"a": "b", "c": "d"} + decorator = self._get_class()("path.to.SyncClass", replace_symbols=replace_symbols) + mock_node = ast.ClassDef(name="AsyncClass", bases=[], keywords=[], body=[]) + symbol_transform_mock = mock.Mock() + globals_mock = {**globals_mock, "SymbolReplacer": symbol_transform_mock} + decorator.sync_ast_transform(mock_node, globals_mock) + # make sure SymbolReplacer was called with replace_symbols + assert symbol_transform_mock.call_count == 1 + found_dict = symbol_transform_mock.call_args[0][0] + assert "a" in found_dict + for k, v in replace_symbols.items(): + assert found_dict[k] == v + # should also add CrossSync replacement + assert found_dict["CrossSync"] == "CrossSync._Sync_Impl" + class TestConvertDecorator: From 7925b24ea5d8fd031d908016107a34af9694fa15 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 5 Sep 2024 14:56:25 -0700 Subject: [PATCH 240/360] added tests for convert --- 
.../data/_sync/cross_sync/_decorators.py | 2 +- .../data/_sync/test_cross_sync_decorators.py | 155 ++++++++++++++++-- 2 files changed, 145 insertions(+), 12 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index 02d9d6158..751f6c882 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -315,7 +315,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): wrapped_node.name, wrapped_node.args, wrapped_node.body, - wrapped_node.decorator_list, + wrapped_node.decorator_list if hasattr(wrapped_node, "decorator_list") else [], wrapped_node.returns, ), wrapped_node, diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_sync/test_cross_sync_decorators.py index f9aff87d8..441fb7dbe 100644 --- a/tests/unit/data/_sync/test_cross_sync_decorators.py +++ b/tests/unit/data/_sync/test_cross_sync_decorators.py @@ -19,18 +19,19 @@ from google.cloud.bigtable.data._sync.cross_sync._decorators import ExportSync, Convert, DropMethod, Pytest, PytestFixture +@pytest.fixture +def globals_mock(): + mock_transform = mock.Mock() + mock_transform().visit = lambda x: x + global_dict = {k: mock_transform for k in ["RmAioFunctions", "SymbolReplacer", "CrossSyncMethodDecoratorHandler"]} + return global_dict + + class TestExportSyncDecorator: def _get_class(self): return ExportSync - @pytest.fixture - def globals_mock(self): - mock_transform = mock.Mock() - mock_transform().visit = lambda x: x - global_dict = {k: mock_transform for k in ["RmAioFunctions", "SymbolReplacer", "CrossSyncMethodDecoratorHandler"]} - return global_dict - def test_ctor_defaults(self): """ Should set default values for path, add_mapping_for_name, and docstring_format_vars @@ -234,11 +235,143 @@ class TestConvertDecorator: def _get_class(self): return Convert - def test_decorator_functionality(self): - pass + def test_ctor_defaults(self): + instance = self._get_class()() + assert instance.sync_name is None + assert instance.replace_symbols is None + assert instance.async_docstring_format_vars == {} + assert instance.sync_docstring_format_vars == {} + assert instance.rm_aio is False + + def test_ctor(self): + sync_name = "sync_name" + replace_symbols = {"a": "b"} + docstring_format_vars = {"A": (1, 2)} + rm_aio = True + + instance = self._get_class()( + sync_name=sync_name, + replace_symbols=replace_symbols, + docstring_format_vars=docstring_format_vars, + rm_aio=rm_aio + ) + assert instance.sync_name is sync_name + assert instance.replace_symbols is replace_symbols + assert instance.async_docstring_format_vars == {"A": 1} + assert instance.sync_docstring_format_vars == {"A": 2} + assert instance.rm_aio is rm_aio + + def test_async_decorator_no_docstring(self): + """ + If no docstring_format_vars is set, should be a no-op + """ + unwrapped_class = mock.Mock + wrapped_class = self._get_class().decorator(unwrapped_class) + assert unwrapped_class == wrapped_class + + @pytest.mark.parametrize("docstring,format_vars,expected", [ + ["test docstring", {}, "test docstring"], + ["{}", {}, "{}"], + ["test_docstring", {"A": (1, 2)}, "test_docstring"], + ["{A}", {"A": (1, 2)}, "1"], + ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "1 3"], + ["hello {world_var}", {"world_var": ("world", "moon")}, "hello world"], + ]) + def test_async_decorator_docstring_update(self, docstring, format_vars, expected): + """ + If docstring_format_vars is set, should 
update the docstring + of the class being decorated + """ + @self._get_class().decorator(docstring_format_vars=format_vars) + class Class: + __doc__ = docstring + assert Class.__doc__ == expected + # check internal state + instance = self._get_class()(docstring_format_vars=format_vars) + async_replacements = {k: v[0] for k, v in format_vars.items()} + sync_replacements = {k: v[1] for k, v in format_vars.items()} + assert instance.async_docstring_format_vars == async_replacements + assert instance.sync_docstring_format_vars == sync_replacements + + def test_sync_ast_transform_remove_adef(self): + """ + Should convert `async def` methods to `def` methods + """ + decorator = self._get_class()() + mock_node = ast.AsyncFunctionDef(name="test_method", args=ast.arguments(), body=[]) + + result = decorator.sync_ast_transform(mock_node, {}) + + assert isinstance(result, ast.FunctionDef) + assert result.name == "test_method" + + def test_sync_ast_transform_replaces_name(self, globals_mock): + """ + Should update the name of the method if sync_name is set + """ + decorator = self._get_class()(sync_name="new_method_name") + mock_node = ast.AsyncFunctionDef(name="old_method_name", args=ast.arguments(), body=[]) + + result = decorator.sync_ast_transform(mock_node, globals_mock) + + assert isinstance(result, ast.FunctionDef) + assert result.name == "new_method_name" + + def test_sync_ast_transform_calls_async_to_sync(self): + """ + Should call AsyncToSync if rm_aio is set + """ + decorator = self._get_class()(rm_aio=True) + mock_node = ast.AsyncFunctionDef(name="test_method", args=ast.arguments(), body=[]) + async_to_sync_mock = mock.Mock() + async_to_sync_mock.visit.return_value = mock_node + globals_mock = {"AsyncToSync": lambda: async_to_sync_mock} + + decorator.sync_ast_transform(mock_node, globals_mock) + assert async_to_sync_mock.visit.call_count == 1 + + def test_sync_ast_transform_replace_symbols(self): + """ + Should call SymbolReplacer with replace_symbols if replace_symbols is set + """ + replace_symbols = {"old_symbol": "new_symbol"} + decorator = self._get_class()(replace_symbols=replace_symbols) + mock_node = ast.AsyncFunctionDef(name="test_method", args=ast.arguments(), body=[]) + symbol_replacer_mock = mock.Mock() + globals_mock = {"SymbolReplacer": symbol_replacer_mock} + + decorator.sync_ast_transform(mock_node, globals_mock) + + assert symbol_replacer_mock.call_count == 1 + assert symbol_replacer_mock.call_args[0][0] == replace_symbols + assert symbol_replacer_mock(replace_symbols).visit.call_count == 1 + + @pytest.mark.parametrize("docstring,format_vars,expected", [ + ["test docstring", {}, "test docstring"], + ["{}", {}, "{}"], + ["test_docstring", {"A": (1, 2)}, "test_docstring"], + ["{A}", {"A": (1, 2)}, "2"], + ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "2 4"], + ["hello {world_var}", {"world_var": ("world", "moon")}, "hello moon"], + ]) + def test_sync_ast_transform_add_docstring_format(self, docstring, format_vars, expected): + """ + If docstring_format_vars is set, should format the docstring of the new method + """ + decorator = self._get_class()(docstring_format_vars=format_vars) + mock_node = ast.AsyncFunctionDef( + name="test_method", + args=ast.arguments(), + body=[ast.Expr(value=ast.Constant(value=docstring))] + ) + + result = decorator.sync_ast_transform(mock_node, {}) + + assert isinstance(result, ast.FunctionDef) + assert isinstance(result.body[0], ast.Expr) + assert isinstance(result.body[0].value, ast.Constant) + assert result.body[0].value.value == expected - def 
test_sync_ast_transform(self): - pass class TestDropMethodDecorator: From de32f7fcb323ebd32b25671152fe38032c5c8e99 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 5 Sep 2024 15:27:36 -0700 Subject: [PATCH 241/360] added remaining decorator tests --- .../data/_sync/cross_sync/_decorators.py | 2 + .../data/_sync/test_cross_sync_decorators.py | 81 +++++++++++++++++-- 2 files changed, 78 insertions(+), 5 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index 751f6c882..a447adfcf 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -414,6 +414,8 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): import copy new_node = copy.deepcopy(wrapped_node) + if not hasattr(new_node, "decorator_list"): + new_node.decorator_list = [] new_node.decorator_list.append( ast.Call( func=ast.Attribute( diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_sync/test_cross_sync_decorators.py index 441fb7dbe..8a50b6f7c 100644 --- a/tests/unit/data/_sync/test_cross_sync_decorators.py +++ b/tests/unit/data/_sync/test_cross_sync_decorators.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest +import pytest_asyncio import ast from unittest import mock from google.cloud.bigtable.data._sync.cross_sync.cross_sync import CrossSync @@ -389,18 +390,61 @@ def test_decorator_functionality(self): assert wrapped(1) == 1 def test_sync_ast_transform(self): - pass + """ + Should return None for any input method + """ + decorator = self._get_class()() + mock_node = ast.AsyncFunctionDef(name="test_method", args=ast.arguments(), body=[]) + + result = decorator.sync_ast_transform(mock_node, {}) + + assert result is None + class TestPytestDecorator: def _get_class(self): return Pytest + def test_ctor(self): + instance = self._get_class()() + assert instance.rm_aio is True + instance = self._get_class()(rm_aio=False) + assert instance.rm_aio is False + def test_decorator_functionality(self): - pass + """ + Should wrap the class with pytest.mark.asyncio + """ + unwrapped_class = mock.Mock + wrapped_class = self._get_class().decorator(unwrapped_class) + assert wrapped_class == pytest.mark.asyncio(unwrapped_class) def test_sync_ast_transform(self): - pass + """ + Should be no-op if rm_aio is not set + """ + decorator = self._get_class()(rm_aio=False) + + input_obj = object() + result = decorator.sync_ast_transform(input_obj, {}) + + assert result is input_obj + + def test_sync_ast_transform_rm_aio(self): + """ + If rm_aio is set, should call AsyncToSync on the class + """ + decorator = self._get_class()() + mock_node = ast.ClassDef(name="AsyncClass", bases=[], keywords=[], body=[]) + + async_to_sync_mock = mock.Mock() + async_to_sync_mock.visit.return_value = mock_node + globals_mock = {"AsyncToSync": lambda: async_to_sync_mock} + + decorator.sync_ast_transform(mock_node, globals_mock) + assert async_to_sync_mock.visit.call_count == 1 + class TestPytestFixtureDecorator: @@ -408,7 +452,34 @@ def _get_class(self): return PytestFixture def test_decorator_functionality(self): - pass + """ + Should wrap the class with pytest_asyncio.fixture + """ + with mock.patch.object(pytest_asyncio, "fixture") as fixture: + @self._get_class().decorator(1, 2, scope="function", params=[3, 4]) + def fn(): + pass + + assert fixture.call_count == 1 + assert fixture.call_args[0] == (1, 2) + assert 
fixture.call_args[1] == {"scope": "function", "params": [3, 4]} def test_sync_ast_transform(self): - pass + """ + Should attach pytest.fixture to generated method + """ + decorator = self._get_class()(1, 2, scope="function") + + mock_node = ast.AsyncFunctionDef(name="test_method", args=ast.arguments(), body=[]) + + result = decorator.sync_ast_transform(mock_node, {}) + + assert isinstance(result, ast.AsyncFunctionDef) + assert len(result.decorator_list) == 1 + assert isinstance(result.decorator_list[0], ast.Call) + assert result.decorator_list[0].func.value.id == "pytest" + assert result.decorator_list[0].func.attr == "fixture" + assert result.decorator_list[0].args[0].value == 1 + assert result.decorator_list[0].args[1].value == 2 + assert result.decorator_list[0].keywords[0].arg == "scope" + assert result.decorator_list[0].keywords[0].value.value == "function" From 3c0f1dea02d86e61e46260e24d9a1cb870115035 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 5 Sep 2024 15:35:20 -0700 Subject: [PATCH 242/360] added mapping tests --- tests/unit/data/_sync/test_cross_sync.py | 42 ++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py index 40bbaa2c8..bd9fe524e 100644 --- a/tests/unit/data/_sync/test_cross_sync.py +++ b/tests/unit/data/_sync/test_cross_sync.py @@ -1,3 +1,16 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import typing import asyncio import pytest @@ -407,3 +420,32 @@ def test_rmaio(self, cs_async): assert cs_async.rm_aio("test") == "test" obj = object() assert cs_async.rm_aio(obj) == obj + + def test_add_mapping(self, cs_sync, cs_async): + """ + Add dynamic attributes to each class using add_mapping() + """ + for cls in [cs_sync, cs_async]: + cls.add_mapping("test", 1) + assert cls.test == 1 + assert cls._runtime_replacements[(cls, "test")] == 1 + + def test_add_duplicate_mapping(self, cs_sync, cs_async): + """ + Adding the same attribute twice should raise an exception + """ + for cls in [cs_sync, cs_async]: + cls.add_mapping("duplicate", 1) + with pytest.raises(AttributeError) as e: + cls.add_mapping("duplicate", 2) + assert "Conflicting assignments" in str(e.value) + + def test_add_mapping_decorator(self, cs_sync, cs_async): + """ + add_mapping_decorator should allow wrapping classes with add_mapping() + """ + for cls in [cs_sync, cs_async]: + @cls.add_mapping_decorator("decorated") + class Decorated: + pass + assert cls.decorated == Decorated From da38ac4332437735c949fea5af7f56511a13fc03 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 5 Sep 2024 16:27:00 -0700 Subject: [PATCH 243/360] added e2e test structure for cross_sync --- .cross_sync/test_cases/async_to_sync.py | 13 ++++++++++ .cross_sync/test_cross_sync_e2e.py | 34 +++++++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 .cross_sync/test_cases/async_to_sync.py create mode 100644 .cross_sync/test_cross_sync_e2e.py diff --git a/.cross_sync/test_cases/async_to_sync.py b/.cross_sync/test_cases/async_to_sync.py new file mode 100644 index 000000000..67d2b1f42 --- /dev/null +++ b/.cross_sync/test_cases/async_to_sync.py @@ -0,0 +1,13 @@ +tests: + - description: "async for loop fn" + before: | + async def func_name(): + async for i in range(10): + await routine() + return 42 + transformers: [AsyncToSync] + after: | + def func_name(): + for i in range(10): + routine() + return 42 diff --git a/.cross_sync/test_cross_sync_e2e.py b/.cross_sync/test_cross_sync_e2e.py new file mode 100644 index 000000000..91bcf8335 --- /dev/null +++ b/.cross_sync/test_cross_sync_e2e.py @@ -0,0 +1,34 @@ +import ast +import sys +import os +import black +import pytest +# add cross_sync to path +sys.path.append("google/cloud/bigtable/data/_sync/cross_sync") +from transformers import SymbolReplacer, AsyncToSync, RmAioFunctions, CrossSyncClassDecoratorHandler, CrossSyncClassDecoratorHandler + + +def loader(): + dir_name = os.path.join(os.path.dirname(__file__), "test_cases") + for file_name in os.listdir(dir_name): + test_case_file = os.path.join(dir_name, file_name) + # load test cases + import yaml + with open(test_case_file) as f: + test_cases = yaml.safe_load(f) + for test in test_cases["tests"]: + test["file_name"] = file_name + yield test + +@pytest.mark.parametrize( + "test_dict", loader(), ids=lambda x: f"{x['file_name']}: {x.get('description', '')}" +) +def test_e2e_scenario(test_dict): + before_ast = ast.parse(test_dict["before"]).body[0] + transformers = [globals()[t] for t in test_dict["transformers"]] + got_ast = before_ast + for transformer in transformers: + got_ast = transformer().visit(got_ast) + final_str = black.format_str(ast.unparse(got_ast), mode=black.FileMode()) + expected_str = black.format_str(test_dict["after"], mode=black.FileMode()) + assert final_str == expected_str, f"Expected:\n{expected_str}\nGot:\n{final_str}" From 169255ba8dcf6e1fbc6b38f29e8a53b20c03e714 Mon Sep 17 00:00:00 2001 From: Daniel Sanche 
Date: Thu, 5 Sep 2024 16:27:17 -0700 Subject: [PATCH 244/360] fixed failing nox tests --- google/cloud/bigtable/data/_sync/cross_sync/_decorators.py | 2 +- tests/unit/data/_sync/test_cross_sync_decorators.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index a447adfcf..6c94b0913 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -316,7 +316,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): wrapped_node.args, wrapped_node.body, wrapped_node.decorator_list if hasattr(wrapped_node, "decorator_list") else [], - wrapped_node.returns, + wrapped_node.returns if hasattr(wrapped_node, "returns") else None, ), wrapped_node, ) diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_sync/test_cross_sync_decorators.py index 8a50b6f7c..d31a1d0ed 100644 --- a/tests/unit/data/_sync/test_cross_sync_decorators.py +++ b/tests/unit/data/_sync/test_cross_sync_decorators.py @@ -110,7 +110,7 @@ def test_class_decorator_docstring_update(self, docstring, format_vars, expected If docstring_format_vars is set, should update the docstring of the class being decorated """ - @self._get_class().decorator(path=1, docstring_format_vars=format_vars) + @ExportSync.decorator(path=1, docstring_format_vars=format_vars) class Class: __doc__ = docstring assert Class.__doc__ == expected @@ -283,7 +283,7 @@ def test_async_decorator_docstring_update(self, docstring, format_vars, expected If docstring_format_vars is set, should update the docstring of the class being decorated """ - @self._get_class().decorator(docstring_format_vars=format_vars) + @Convert.decorator(docstring_format_vars=format_vars) class Class: __doc__ = docstring assert Class.__doc__ == expected @@ -456,7 +456,7 @@ def test_decorator_functionality(self): Should wrap the class with pytest_asyncio.fixture """ with mock.patch.object(pytest_asyncio, "fixture") as fixture: - @self._get_class().decorator(1, 2, scope="function", params=[3, 4]) + @PytestFixture.decorator(1, 2, scope="function", params=[3, 4]) def fn(): pass From 5eddd03042d07d985c753e6afef92d069d884037 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 5 Sep 2024 16:43:24 -0700 Subject: [PATCH 245/360] added test cases --- .cross_sync/test_cases/async_to_sync.py | 13 ---- .cross_sync/test_cases/async_to_sync.yaml | 76 +++++++++++++++++++++++ .cross_sync/test_cases/rm_aio.yaml | 63 +++++++++++++++++++ .cross_sync/test_cross_sync_e2e.py | 6 +- 4 files changed, 144 insertions(+), 14 deletions(-) delete mode 100644 .cross_sync/test_cases/async_to_sync.py create mode 100644 .cross_sync/test_cases/async_to_sync.yaml create mode 100644 .cross_sync/test_cases/rm_aio.yaml diff --git a/.cross_sync/test_cases/async_to_sync.py b/.cross_sync/test_cases/async_to_sync.py deleted file mode 100644 index 67d2b1f42..000000000 --- a/.cross_sync/test_cases/async_to_sync.py +++ /dev/null @@ -1,13 +0,0 @@ -tests: - - description: "async for loop fn" - before: | - async def func_name(): - async for i in range(10): - await routine() - return 42 - transformers: [AsyncToSync] - after: | - def func_name(): - for i in range(10): - routine() - return 42 diff --git a/.cross_sync/test_cases/async_to_sync.yaml b/.cross_sync/test_cases/async_to_sync.yaml new file mode 100644 index 000000000..99d39cbc5 --- /dev/null +++ 
b/.cross_sync/test_cases/async_to_sync.yaml @@ -0,0 +1,76 @@ +tests: + - description: "async for loop fn" + before: | + async def func_name(): + async for i in range(10): + await routine() + return 42 + transformers: [AsyncToSync] + after: | + def func_name(): + for i in range(10): + routine() + return 42 + + - description: "async with statement" + before: | + async def func_name(): + async with context_manager() as cm: + await do_something(cm) + transformers: [AsyncToSync] + after: | + def func_name(): + with context_manager() as cm: + do_something(cm) + + - description: "async function definition" + before: | + async def async_function(param1, param2): + result = await some_coroutine() + return result + transformers: [AsyncToSync] + after: | + def async_function(param1, param2): + result = some_coroutine() + return result + + - description: "list comprehension with async for" + before: | + async def func_name(): + result = [x async for x in aiter() if await predicate(x)] + transformers: [AsyncToSync] + after: | + def func_name(): + result = [x for x in aiter() if predicate(x)] + + - description: "multiple async features in one function" + before: | + async def complex_function(): + async with resource_manager() as res: + async for item in res.items(): + if await check(item): + yield await process(item) + transformers: [AsyncToSync] + after: | + def complex_function(): + with resource_manager() as res: + for item in res.items(): + if check(item): + yield process(item) + + - description: "nested async constructs" + before: | + async def nested_async(): + async with outer_context(): + async for x in outer_iter(): + async with inner_context(x): + async for y in inner_iter(x): + await process(x, y) + transformers: [AsyncToSync] + after: | + def nested_async(): + with outer_context(): + for x in outer_iter(): + with inner_context(x): + for y in inner_iter(x): + process(x, y) diff --git a/.cross_sync/test_cases/rm_aio.yaml b/.cross_sync/test_cases/rm_aio.yaml new file mode 100644 index 000000000..a6434dae8 --- /dev/null +++ b/.cross_sync/test_cases/rm_aio.yaml @@ -0,0 +1,63 @@ +tests: + - description: "remove await" + before: | + CrossSync.rm_aio(await routine()) + transformers: [RmAioFunctions] + after: | + routine() + - description: "async for loop fn" + before: | + async def func_name(): + async for i in CrossSync.rm_aio(range(10)): + await routine() + return 42 + transformers: [RmAioFunctions] + after: | + async def func_name(): + for i in range(10): + await routine() + return 42 + + - description: "async with statement" + before: | + async def func_name(): + async with CrossSync.rm_aio(context_manager()) as cm: + await do_something(cm) + transformers: [RmAioFunctions] + after: | + async def func_name(): + with context_manager() as cm: + await do_something(cm) + + - description: "list comprehension with async for" + before: | + async def func_name(): + result = CrossSync.rm_aio([x async for x in aiter() if await predicate(x)]) + transformers: [RmAioFunctions] + after: | + async def func_name(): + result = [x for x in aiter() if predicate(x)] + + - description: "multiple async features in one call" + before: | + CrossSync.rm_aio([x async for x in aiter() if await predicate(x)] + await routine()) + transformers: [RmAioFunctions] + after: | + [x for x in aiter() if predicate(x)] + routine() + + - description: "do nothing with no CrossSync.rm_aio" + before: | + async def nested_async(): + async with outer_context(): + async for x in outer_iter(): + async with inner_context(x): + async for y in 
inner_iter(x): + await process(x, y) + transformers: [RmAioFunctions] + after: | + async def nested_async(): + async with outer_context(): + async for x in outer_iter(): + async with inner_context(x): + async for y in inner_iter(x): + await process(x, y) diff --git a/.cross_sync/test_cross_sync_e2e.py b/.cross_sync/test_cross_sync_e2e.py index 91bcf8335..3dc14f0cb 100644 --- a/.cross_sync/test_cross_sync_e2e.py +++ b/.cross_sync/test_cross_sync_e2e.py @@ -3,6 +3,7 @@ import os import black import pytest +import yaml # add cross_sync to path sys.path.append("google/cloud/bigtable/data/_sync/cross_sync") from transformers import SymbolReplacer, AsyncToSync, RmAioFunctions, CrossSyncClassDecoratorHandler, CrossSyncClassDecoratorHandler @@ -11,10 +12,13 @@ def loader(): dir_name = os.path.join(os.path.dirname(__file__), "test_cases") for file_name in os.listdir(dir_name): + if not file_name.endswith(".yaml"): + print(f"Skipping {file_name}") + continue test_case_file = os.path.join(dir_name, file_name) # load test cases - import yaml with open(test_case_file) as f: + print(f"Loading test cases from {test_case_file}") test_cases = yaml.safe_load(f) for test in test_cases["tests"]: test["file_name"] = file_name From c616143e23b77ed49f3f39240ce748ab498804ee Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 5 Sep 2024 16:57:44 -0700 Subject: [PATCH 246/360] added tests --- .cross_sync/test_cases/rm_aio.yaml | 46 ++++++++++++ .cross_sync/test_cases/symbol_replacer.yaml | 82 +++++++++++++++++++++ .cross_sync/test_cross_sync_e2e.py | 13 +++- .cross_sync/transformers.py | 2 +- 4 files changed, 139 insertions(+), 4 deletions(-) create mode 100644 .cross_sync/test_cases/symbol_replacer.yaml diff --git a/.cross_sync/test_cases/rm_aio.yaml b/.cross_sync/test_cases/rm_aio.yaml index a6434dae8..89acda630 100644 --- a/.cross_sync/test_cases/rm_aio.yaml +++ b/.cross_sync/test_cases/rm_aio.yaml @@ -61,3 +61,49 @@ tests: async with inner_context(x): async for y in inner_iter(x): await process(x, y) + + - description: "nested async for loops with rm_aio" + before: | + async def nested_loops(): + async for x in CrossSync.rm_aio(outer_iter()): + async for y in CrossSync.rm_aio(inner_iter(x)): + await process(x, y) + transformers: [RmAioFunctions] + after: | + async def nested_loops(): + for x in outer_iter(): + for y in inner_iter(x): + await process(x, y) + + - description: "async generator function with rm_aio" + before: | + async def async_gen(): + yield CrossSync.rm_aio(await async_value()) + async for item in CrossSync.rm_aio(async_iterator()): + yield item + transformers: [RmAioFunctions] + after: | + async def async_gen(): + yield async_value() + for item in async_iterator(): + yield item + + - description: "async with statement with multiple context managers" + before: | + async def multi_context(): + async with CrossSync.rm_aio(cm1()), CrossSync.rm_aio(cm2()) as c2, CrossSync.rm_aio(cm3()) as c3: + await do_something(c2, c3) + transformers: [RmAioFunctions] + after: | + async def multi_context(): + with cm1(), cm2() as c2, cm3() as c3: + await do_something(c2, c3) + + - description: "async comprehension with multiple async for and if clauses" + before: | + async def complex_comprehension(): + result = CrossSync.rm_aio([x async for x in aiter1() if await pred1(x) async for y in aiter2(x) if await pred2(y)]) + transformers: [RmAioFunctions] + after: | + async def complex_comprehension(): + result = [x for x in aiter1() if pred1(x) for y in aiter2(x) if pred2(y)] diff --git 
a/.cross_sync/test_cases/symbol_replacer.yaml b/.cross_sync/test_cases/symbol_replacer.yaml new file mode 100644 index 000000000..fa50045f8 --- /dev/null +++ b/.cross_sync/test_cases/symbol_replacer.yaml @@ -0,0 +1,82 @@ +tests: + - description: "Does not Replace function name" + before: | + def function(): + pass + transformers: + - name: SymbolReplacer + args: + replacements: {"function": "new_function"} + after: | + def function(): + pass + + - description: "Does not replace async function name" + before: | + async def async_func(): + await old_coroutine() + transformers: + - name: SymbolReplacer + args: + replacements: {"async_func": "new_async_func", "old_coroutine": "new_coroutine"} + after: | + async def async_func(): + await new_coroutine() + + - description: "Replace method call" + before: | + result = obj.old_method() + transformers: + - name: SymbolReplacer + args: + replacements: {"old_method": "new_method"} + after: | + result = obj.new_method() + + - description: "Replace in docstring" + before: | + def func(): + """This is a docstring mentioning old_name.""" + pass + transformers: + - name: SymbolReplacer + args: + replacements: {"old_name": "new_name"} + after: | + def func(): + """This is a docstring mentioning new_name.""" + pass + + - description: "Replace in type annotation" + before: | + def func(param: OldType) -> OldReturnType: + pass + transformers: + - name: SymbolReplacer + args: + replacements: {"OldType": "NewType", "OldReturnType": "NewReturnType"} + after: | + def func(param: NewType) -> NewReturnType: + pass + + - description: "Replace in nested attribute" + before: | + result = obj.attr1.attr2.old_attr + transformers: + - name: SymbolReplacer + args: + replacements: {"old_attr": "new_attr"} + after: | + result = obj.attr1.attr2.new_attr + + - description: "No replacement when symbol not found" + before: | + def unchanged_function(): + pass + transformers: + - name: SymbolReplacer + args: + replacements: {"non_existent": "replacement"} + after: | + def unchanged_function(): + pass diff --git a/.cross_sync/test_cross_sync_e2e.py b/.cross_sync/test_cross_sync_e2e.py index 3dc14f0cb..28286830a 100644 --- a/.cross_sync/test_cross_sync_e2e.py +++ b/.cross_sync/test_cross_sync_e2e.py @@ -29,10 +29,17 @@ def loader(): ) def test_e2e_scenario(test_dict): before_ast = ast.parse(test_dict["before"]).body[0] - transformers = [globals()[t] for t in test_dict["transformers"]] got_ast = before_ast - for transformer in transformers: - got_ast = transformer().visit(got_ast) + for transformer_info in test_dict["transformers"]: + # transformer can be passed as a string, or a dict with name and args + if isinstance(transformer_info, str): + transformer_class = globals()[transformer_info] + transformer_args = {} + else: + transformer_class = globals()[transformer_info["name"]] + transformer_args = transformer_info.get("args", {}) + transformer = transformer_class(**transformer_args) + got_ast = transformer.visit(got_ast) final_str = black.format_str(ast.unparse(got_ast), mode=black.FileMode()) expected_str = black.format_str(test_dict["after"], mode=black.FileMode()) assert final_str == expected_str, f"Expected:\n{expected_str}\nGot:\n{final_str}" diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index a5c4eeb6a..903691b4b 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -79,7 +79,7 @@ def visit_FunctionDef(self, node): node.body[0].value.s = docstring return self.generic_visit(node) - def visit_Str(self, node): + def 
visit_Constant(self, node): """Replace string type annotations""" node.s = self.replacements.get(node.s, node.s) return node From 042f89be63940e54cb1866e0921e7785ad568afe Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 5 Sep 2024 17:20:56 -0700 Subject: [PATCH 247/360] added cross_sync_methods e2e test cases --- .../test_cases/cross_sync_methods.yaml | 144 ++++++++++++++++++ .cross_sync/test_cross_sync_e2e.py | 2 +- 2 files changed, 145 insertions(+), 1 deletion(-) create mode 100644 .cross_sync/test_cases/cross_sync_methods.yaml diff --git a/.cross_sync/test_cases/cross_sync_methods.yaml b/.cross_sync/test_cases/cross_sync_methods.yaml new file mode 100644 index 000000000..ca2222a52 --- /dev/null +++ b/.cross_sync/test_cases/cross_sync_methods.yaml @@ -0,0 +1,144 @@ +tests: + - description: "Convert async method with @CrossSync.convert" + before: | + @CrossSync.convert + async def my_method(self, arg): + pass + transformers: [CrossSyncMethodDecoratorHandler] + after: | + def my_method(self, arg): + pass + + - description: "Convert async method with custom sync name" + before: | + @CrossSync.convert(sync_name="sync_method") + async def async_method(self, arg): + return await self.helper(arg) + transformers: [CrossSyncMethodDecoratorHandler] + after: | + def sync_method(self, arg): + return await self.helper(arg) + + - description: "Convert async method with symbol replacement" + before: | + @CrossSync.convert(replace_symbols={"old": "new"}) + async def my_method(self): + old = 1 + transformers: [CrossSyncMethodDecoratorHandler] + after: | + def my_method(self): + new = 1 + + - description: "Convert async method with rm_aio=True" + before: | + @CrossSync.convert(rm_aio=True) + async def async_method(self): + async with self.lock: + async for item in self.items: + await self.process(item) + transformers: [CrossSyncMethodDecoratorHandler] + after: | + def async_method(self): + with self.lock: + for item in self.items: + self.process(item) + + - description: "Convert async method with docstring formatting" + before: | + @CrossSync.convert(docstring_format_vars={"mode": ("async", "sync")}) + async def async_method(self): + """This is a {mode} method.""" + transformers: [CrossSyncMethodDecoratorHandler] + after: | + def async_method(self): + """This is a sync method.""" + + - description: "Drop method from sync version" + before: | + def keep_method(self): + pass + + @CrossSync.drop_method + async def async_only_method(self): + await self.async_operation() + transformers: [CrossSyncMethodDecoratorHandler] + after: | + def keep_method(self): + pass + + - description: "Convert.pytest" + before: | + @CrossSync.pytest + async def test_async_function(): + result = await async_operation() + assert result == expected_value + transformers: [CrossSyncMethodDecoratorHandler] + after: | + def test_async_function(): + result = async_operation() + assert result == expected_value + + - description: "CrossSync.pytest with rm_aio=False" + before: | + @CrossSync.pytest(rm_aio=False) + async def test_partial_async(): + async with context_manager(): + result = await async_function() + assert result == expected_value + transformers: [CrossSyncMethodDecoratorHandler] + after: | + async def test_partial_async(): + async with context_manager(): + result = await async_function() + assert result == expected_value + + - description: "Convert pytest fixture with custom parameters" + before: | + @CrossSync.pytest_fixture(scope="module", autouse=True) + async def async_fixture(): + resource = await setup_resource() 
+ yield resource + await cleanup_resource(resource) + transformers: [CrossSyncMethodDecoratorHandler] + after: | + @pytest.fixture(scope="module", autouse=True) + async def async_fixture(): + resource = await setup_resource() + yield resource + await cleanup_resource(resource) + + - description: "Convert method with multiple stacked decorators" + before: | + @CrossSync.convert(sync_name="sync_multi_decorated") + @CrossSync.pytest + @some_other_decorator + async def async_multi_decorated(self, arg): + result = await self.async_operation(arg) + return result + transformers: [CrossSyncMethodDecoratorHandler] + after: | + @some_other_decorator + def sync_multi_decorated(self, arg): + result = self.async_operation(arg) + return result + + - description: "Convert method with stacked decorators including rm_aio" + before: | + @CrossSync.convert(rm_aio=True) + @CrossSync.pytest_fixture(scope="function") + @another_decorator + async def async_fixture_with_context(): + async with some_async_context(): + resource = await setup_async_resource() + yield resource + await cleanup_async_resource(resource) + transformers: [CrossSyncMethodDecoratorHandler] + after: | + @pytest.fixture(scope="function") + @another_decorator + def async_fixture_with_context(): + with some_async_context(): + resource = setup_async_resource() + yield resource + cleanup_async_resource(resource) + diff --git a/.cross_sync/test_cross_sync_e2e.py b/.cross_sync/test_cross_sync_e2e.py index 28286830a..b582ac79a 100644 --- a/.cross_sync/test_cross_sync_e2e.py +++ b/.cross_sync/test_cross_sync_e2e.py @@ -6,7 +6,7 @@ import yaml # add cross_sync to path sys.path.append("google/cloud/bigtable/data/_sync/cross_sync") -from transformers import SymbolReplacer, AsyncToSync, RmAioFunctions, CrossSyncClassDecoratorHandler, CrossSyncClassDecoratorHandler +from transformers import SymbolReplacer, AsyncToSync, RmAioFunctions, CrossSyncMethodDecoratorHandler, CrossSyncClassDecoratorHandler def loader(): From f186d6b2f77d12763e330fb5dcb2da25b04c128f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 6 Sep 2024 14:30:42 -0700 Subject: [PATCH 248/360] added tests for class generation --- .../test_cases/cross_sync_classes.yaml | 167 ++++++++++++++++++ .cross_sync/transformers.py | 3 +- 2 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 .cross_sync/test_cases/cross_sync_classes.yaml diff --git a/.cross_sync/test_cases/cross_sync_classes.yaml b/.cross_sync/test_cases/cross_sync_classes.yaml new file mode 100644 index 000000000..f38335e87 --- /dev/null +++ b/.cross_sync/test_cases/cross_sync_classes.yaml @@ -0,0 +1,167 @@ +tests: + - description: "No conversion needed" + before: | + @CrossSync.export_sync(path="example.sync.MyClass") + class MyAsyncClass: + async def my_method(self): + pass + + transformers: + - name: CrossSyncClassDecoratorHandler + args: + file_path: "dummy_path.py" + after: | + class MyClass: + + async def my_method(self): + pass + + - description: "CrossSync.export_sync with replace_symbols" + before: | + @CrossSync.export_sync( + path="example.sync.MyClass", + replace_symbols={"AsyncBase": "SyncBase", "ParentA": "ParentB"} + ) + class MyAsyncClass(ParentA): + def __init__(self, base: AsyncBase): + self.base = base + + transformers: + - name: CrossSyncClassDecoratorHandler + args: + file_path: "dummy_path.py" + after: | + class MyClass(ParentB): + + def __init__(self, base: SyncBase): + self.base = base + + - description: "CrossSync.export_sync with docstring formatting" + before: | + @CrossSync.export_sync( + 
path="example.sync.MyClass", + docstring_format_vars={"type": ("async", "sync")} + ) + class MyAsyncClass: + """This is a {type} class.""" + + transformers: + - name: CrossSyncClassDecoratorHandler + args: + file_path: "dummy_path.py" + after: | + class MyClass: + """This is a sync class.""" + + - description: "CrossSync.export_sync with multiple decorators and methods" + before: | + @CrossSync.export_sync(path="example.sync.MyClass") + @some_other_decorator + class MyAsyncClass: + @CrossSync.convert + async def my_method(self): + async with self.base.connection(): + return await self.base.my_method() + + @CrossSync.drop_method + async def async_only_method(self): + await self.async_operation() + + def sync_method(self): + return "This method stays the same" + + @CrossSync.pytest_fixture + def fixture(self): + pass + + transformers: + - name: CrossSyncClassDecoratorHandler + args: + file_path: "dummy_path.py" + after: | + @some_other_decorator + class MyClass: + + def my_method(self): + async with self.base.connection(): + return await self.base.my_method() + + def sync_method(self): + return "This method stays the same" + + @pytest.fixture() + def fixture(self): + pass + + - description: "CrossSync.export_sync with nested classes" + before: | + @CrossSync.export_sync(path="example.sync.MyClass", replace_symbols={"AsyncBase": "SyncBase"}) + class MyAsyncClass: + class NestedAsyncClass: + async def nested_method(self, base: AsyncBase): + pass + + @CrossSync.drop_method + async def drop_this_method(self): + pass + + @CrossSync.convert + async def use_nested(self): + nested = self.NestedAsyncClass() + CrossSync.rm_aio(await nested.nested_method()) + transformers: + - name: CrossSyncClassDecoratorHandler + args: + file_path: "dummy_path.py" + after: | + class MyClass: + + class NestedAsyncClass: + + async def nested_method(self, base: SyncBase): + pass + + def use_nested(self): + nested = self.NestedAsyncClass() + nested.nested_method() + + - description: "CrossSync.export_sync with add_mapping" + before: | + @CrossSync.export_sync( + path="example.sync.MyClass", + add_mapping_for_name="MyClass" + ) + class MyAsyncClass: + async def my_method(self): + pass + + transformers: + - name: CrossSyncClassDecoratorHandler + args: + file_path: "dummy_path.py" + after: | + @CrossSync._Sync_Impl.add_mapping_decorator("MyClass") + class MyClass: + + async def my_method(self): + pass + + - description: "CrossSync.export_sync with CrossSync calls" + before: | + @CrossSync.export_sync(path="example.sync.MyClass") + class MyAsyncClass: + @CrossSync.convert + async def my_method(self): + async with CrossSync.rm_aio(CrossSync.Condition()) as c: + CrossSync.rm_aio(await CrossSync.yield_to_event_loop()) + + transformers: + - name: CrossSyncClassDecoratorHandler + args: + file_path: "dummy_path.py" + after: | + class MyClass: + + def my_method(self): + with CrossSync._Sync_Impl.Condition() as c: + CrossSync._Sync_Impl.yield_to_event_loop() diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 903691b4b..b6a34c690 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -271,6 +271,7 @@ def visit_ClassDef(self, node): and avoid duplicate writes """ try: + converted = None for decorator in node.decorator_list: try: handler = AstDecorator.get_for_node(decorator) @@ -300,7 +301,7 @@ def visit_ClassDef(self, node): self._artifact_dict[out_file] = output_artifact except ValueError: continue - return node + return converted except ValueError as e: raise ValueError(f"failed for 
class: {node.name}") from e From 7a0c638bec0fd82eecc9387d58663c4bc10da794 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 6 Sep 2024 14:48:31 -0700 Subject: [PATCH 249/360] moved cross_sync tests into system tests --- noxfile.py | 4 +++- .../system/cross_sync}/test_cases/async_to_sync.yaml | 0 .../system/cross_sync}/test_cases/cross_sync_classes.yaml | 0 .../system/cross_sync}/test_cases/cross_sync_methods.yaml | 0 .../system/cross_sync}/test_cases/rm_aio.yaml | 0 .../system/cross_sync}/test_cases/symbol_replacer.yaml | 0 .../system/cross_sync}/test_cross_sync_e2e.py | 8 ++++++-- 7 files changed, 9 insertions(+), 3 deletions(-) rename {.cross_sync => tests/system/cross_sync}/test_cases/async_to_sync.yaml (100%) rename {.cross_sync => tests/system/cross_sync}/test_cases/cross_sync_classes.yaml (100%) rename {.cross_sync => tests/system/cross_sync}/test_cases/cross_sync_methods.yaml (100%) rename {.cross_sync => tests/system/cross_sync}/test_cases/rm_aio.yaml (100%) rename {.cross_sync => tests/system/cross_sync}/test_cases/symbol_replacer.yaml (100%) rename {.cross_sync => tests/system/cross_sync}/test_cross_sync_e2e.py (84%) diff --git a/noxfile.py b/noxfile.py index 5fb94526d..22a3a4ed0 100644 --- a/noxfile.py +++ b/noxfile.py @@ -48,7 +48,7 @@ UNIT_TEST_EXTRAS: List[str] = [] UNIT_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = {} -SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.8"] +SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.8", "3.12"] SYSTEM_TEST_STANDARD_DEPENDENCIES: List[str] = [ "mock", "pytest", @@ -56,6 +56,8 @@ ] SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ "pytest-asyncio==0.21.2", + "black==23.7.0", + "pyyaml==6.0.2", ] SYSTEM_TEST_LOCAL_DEPENDENCIES: List[str] = [] SYSTEM_TEST_DEPENDENCIES: List[str] = [] diff --git a/.cross_sync/test_cases/async_to_sync.yaml b/tests/system/cross_sync/test_cases/async_to_sync.yaml similarity index 100% rename from .cross_sync/test_cases/async_to_sync.yaml rename to tests/system/cross_sync/test_cases/async_to_sync.yaml diff --git a/.cross_sync/test_cases/cross_sync_classes.yaml b/tests/system/cross_sync/test_cases/cross_sync_classes.yaml similarity index 100% rename from .cross_sync/test_cases/cross_sync_classes.yaml rename to tests/system/cross_sync/test_cases/cross_sync_classes.yaml diff --git a/.cross_sync/test_cases/cross_sync_methods.yaml b/tests/system/cross_sync/test_cases/cross_sync_methods.yaml similarity index 100% rename from .cross_sync/test_cases/cross_sync_methods.yaml rename to tests/system/cross_sync/test_cases/cross_sync_methods.yaml diff --git a/.cross_sync/test_cases/rm_aio.yaml b/tests/system/cross_sync/test_cases/rm_aio.yaml similarity index 100% rename from .cross_sync/test_cases/rm_aio.yaml rename to tests/system/cross_sync/test_cases/rm_aio.yaml diff --git a/.cross_sync/test_cases/symbol_replacer.yaml b/tests/system/cross_sync/test_cases/symbol_replacer.yaml similarity index 100% rename from .cross_sync/test_cases/symbol_replacer.yaml rename to tests/system/cross_sync/test_cases/symbol_replacer.yaml diff --git a/.cross_sync/test_cross_sync_e2e.py b/tests/system/cross_sync/test_cross_sync_e2e.py similarity index 84% rename from .cross_sync/test_cross_sync_e2e.py rename to tests/system/cross_sync/test_cross_sync_e2e.py index b582ac79a..09ba62ab4 100644 --- a/.cross_sync/test_cross_sync_e2e.py +++ b/tests/system/cross_sync/test_cross_sync_e2e.py @@ -5,12 +5,15 @@ import pytest import yaml # add cross_sync to path -sys.path.append("google/cloud/bigtable/data/_sync/cross_sync") +test_dir_name = 
os.path.dirname(__file__) +cross_sync_path = os.path.join(test_dir_name, "..", "..", "..", ".cross_sync") +sys.path.append(cross_sync_path) + from transformers import SymbolReplacer, AsyncToSync, RmAioFunctions, CrossSyncMethodDecoratorHandler, CrossSyncClassDecoratorHandler def loader(): - dir_name = os.path.join(os.path.dirname(__file__), "test_cases") + dir_name = os.path.join(test_dir_name, "test_cases") for file_name in os.listdir(dir_name): if not file_name.endswith(".yaml"): print(f"Skipping {file_name}") @@ -27,6 +30,7 @@ def loader(): @pytest.mark.parametrize( "test_dict", loader(), ids=lambda x: f"{x['file_name']}: {x.get('description', '')}" ) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="ast.unparse requires python3.9 or higher") def test_e2e_scenario(test_dict): before_ast = ast.parse(test_dict["before"]).body[0] got_ast = before_ast From a3cb9a6a6614bb9cbbfae48a2ed4c12bcc676a58 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 6 Sep 2024 14:54:10 -0700 Subject: [PATCH 250/360] fixed lint issues --- .../data/_sync/cross_sync/_decorators.py | 30 ++- .../system/cross_sync/test_cross_sync_e2e.py | 14 +- tests/unit/data/_sync/test_cross_sync.py | 63 ++++-- .../data/_sync/test_cross_sync_decorators.py | 179 ++++++++++++------ 4 files changed, 196 insertions(+), 90 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index 6c94b0913..9c83079fb 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -196,8 +196,12 @@ def __init__( self.path = path self.replace_symbols = replace_symbols docstring_format_vars = docstring_format_vars or {} - self.async_docstring_format_vars = {k: v[0] for k, v in docstring_format_vars.items()} - self.sync_docstring_format_vars = {k: v[1] for k, v in docstring_format_vars.items()} + self.async_docstring_format_vars = { + k: v[0] for k, v in docstring_format_vars.items() + } + self.sync_docstring_format_vars = { + k: v[1] for k, v in docstring_format_vars.items() + } self.mypy_ignore = mypy_ignore self.include_file_imports = include_file_imports self.add_mapping_for_name = add_mapping_for_name @@ -271,7 +275,9 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): # update docstring if specified if self.sync_docstring_format_vars: docstring = ast.get_docstring(wrapped_node) - wrapped_node.body[0].value.s = docstring.format(**self.sync_docstring_format_vars) + wrapped_node.body[0].value.s = docstring.format( + **self.sync_docstring_format_vars + ) return wrapped_node @@ -299,8 +305,12 @@ def __init__( self.sync_name = sync_name self.replace_symbols = replace_symbols docstring_format_vars = docstring_format_vars or {} - self.async_docstring_format_vars = {k: v[0] for k, v in docstring_format_vars.items()} - self.sync_docstring_format_vars = {k: v[1] for k, v in docstring_format_vars.items()} + self.async_docstring_format_vars = { + k: v[0] for k, v in docstring_format_vars.items() + } + self.sync_docstring_format_vars = { + k: v[1] for k, v in docstring_format_vars.items() + } self.rm_aio = rm_aio def sync_ast_transform(self, wrapped_node, transformers_globals): @@ -315,7 +325,9 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): wrapped_node.name, wrapped_node.args, wrapped_node.body, - wrapped_node.decorator_list if hasattr(wrapped_node, "decorator_list") else [], + wrapped_node.decorator_list + if hasattr(wrapped_node, 
"decorator_list") + else [], wrapped_node.returns if hasattr(wrapped_node, "returns") else None, ), wrapped_node, @@ -333,7 +345,9 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): # update docstring if specified if self.sync_docstring_format_vars: docstring = ast.get_docstring(wrapped_node) - wrapped_node.body[0].value.s = docstring.format(**self.sync_docstring_format_vars) + wrapped_node.body[0].value.s = docstring.format( + **self.sync_docstring_format_vars + ) return wrapped_node def async_decorator(self): @@ -342,9 +356,11 @@ def async_decorator(self): """ if self.async_docstring_format_vars: + def decorator(f): f.__doc__ = f.__doc__.format(**self.async_docstring_format_vars) return f + return decorator else: return None diff --git a/tests/system/cross_sync/test_cross_sync_e2e.py b/tests/system/cross_sync/test_cross_sync_e2e.py index 09ba62ab4..489e042fe 100644 --- a/tests/system/cross_sync/test_cross_sync_e2e.py +++ b/tests/system/cross_sync/test_cross_sync_e2e.py @@ -4,12 +4,19 @@ import black import pytest import yaml + # add cross_sync to path test_dir_name = os.path.dirname(__file__) cross_sync_path = os.path.join(test_dir_name, "..", "..", "..", ".cross_sync") sys.path.append(cross_sync_path) -from transformers import SymbolReplacer, AsyncToSync, RmAioFunctions, CrossSyncMethodDecoratorHandler, CrossSyncClassDecoratorHandler +from transformers import ( # noqa: F401 E402 + SymbolReplacer, + AsyncToSync, + RmAioFunctions, + CrossSyncMethodDecoratorHandler, + CrossSyncClassDecoratorHandler, +) def loader(): @@ -27,10 +34,13 @@ def loader(): test["file_name"] = file_name yield test + @pytest.mark.parametrize( "test_dict", loader(), ids=lambda x: f"{x['file_name']}: {x.get('description', '')}" ) -@pytest.mark.skipif(sys.version_info < (3, 9), reason="ast.unparse requires python3.9 or higher") +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="ast.unparse requires python3.9 or higher" +) def test_e2e_scenario(test_dict): before_ast = ast.parse(test_dict["before"]).body[0] got_ast = before_ast diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py index bd9fe524e..8db4670e5 100644 --- a/tests/unit/data/_sync/test_cross_sync.py +++ b/tests/unit/data/_sync/test_cross_sync.py @@ -24,8 +24,8 @@ from google.cloud.bigtable.data._sync.cross_sync.cross_sync import CrossSync, T from unittest import mock -class TestCrossSync: +class TestCrossSync: async def async_iter(self, in_list): for i in in_list: yield i @@ -38,14 +38,22 @@ def cs_sync(self): def cs_async(self): return CrossSync - @pytest.mark.parametrize( - "attr, async_version, sync_version", [ + "attr, async_version, sync_version", + [ ("is_async", True, False), ("sleep", asyncio.sleep, time.sleep), ("wait", asyncio.wait, concurrent.futures.wait), - ("retry_target", api_core.retry.retry_target_async, api_core.retry.retry_target), - ("retry_target_stream", api_core.retry.retry_target_stream_async, api_core.retry.retry_target_stream), + ( + "retry_target", + api_core.retry.retry_target_async, + api_core.retry.retry_target, + ), + ( + "retry_target_stream", + api_core.retry.retry_target_stream_async, + api_core.retry.retry_target_stream, + ), ("Retry", api_core.retry.AsyncRetry, api_core.retry.Retry), ("Queue", asyncio.Queue, queue.Queue), ("Condition", asyncio.Condition, threading.Condition), @@ -59,14 +67,18 @@ def cs_async(self): ("Iterable", typing.AsyncIterable, typing.Iterable), ("Iterator", typing.AsyncIterator, typing.Iterator), ("Generator", typing.AsyncGenerator, 
typing.Generator), - ] + ], ) - def test_alias_attributes(self, attr, async_version, sync_version, cs_sync, cs_async): + def test_alias_attributes( + self, attr, async_version, sync_version, cs_sync, cs_async + ): """ Test basic alias attributes, to ensure they point to the right place in both sync and async versions. """ - assert getattr(cs_async, attr) == async_version, f"Failed async version for {attr}" + assert ( + getattr(cs_async, attr) == async_version + ), f"Failed async version for {attr}" assert getattr(cs_sync, attr) == sync_version, f"Failed sync version for {attr}" @pytest.mark.asyncio @@ -121,7 +133,7 @@ def test_gather_partials_with_excepptions(self, cs_sync): Test sync version of CrossSync.gather_partials() with exceptions """ with concurrent.futures.ThreadPoolExecutor() as e: - partials = [lambda i=i: i + 1 if i != 3 else 1/0 for i in range(5)] + partials = [lambda i=i: i + 1 if i != 3 else 1 / 0 for i in range(5)] with pytest.raises(ZeroDivisionError): cs_sync.gather_partials(partials, sync_executor=e) @@ -130,8 +142,10 @@ def test_gather_partials_return_exceptions(self, cs_sync): Test sync version of CrossSync.gather_partials() with return_exceptions=True """ with concurrent.futures.ThreadPoolExecutor() as e: - partials = [lambda i=i: i + 1 if i != 3 else 1/0 for i in range(5)] - results = cs_sync.gather_partials(partials, return_exceptions=True, sync_executor=e) + partials = [lambda i=i: i + 1 if i != 3 else 1 / 0 for i in range(5)] + results = cs_sync.gather_partials( + partials, return_exceptions=True, sync_executor=e + ) assert len(results) == 5 assert results[0] == 1 assert results[1] == 2 @@ -145,7 +159,7 @@ def test_gather_partials_no_executor(self, cs_sync): """ partials = [lambda i=i: i + 1 for i in range(5)] with pytest.raises(ValueError) as e: - results = cs_sync.gather_partials(partials) + cs_sync.gather_partials(partials) assert "sync_executor is required" in str(e.value) @pytest.mark.asyncio @@ -153,6 +167,7 @@ async def test_gather_partials_async(self, cs_async): """ Test async version of CrossSync.gather_partials() """ + async def coro(i): return i + 1 @@ -165,8 +180,9 @@ async def test_gather_partials_async_with_exceptions(self, cs_async): """ Test async version of CrossSync.gather_partials() with exceptions """ + async def coro(i): - return i + 1 if i != 3 else 1/0 + return i + 1 if i != 3 else 1 / 0 partials = [functools.partial(coro, i) for i in range(5)] with pytest.raises(ZeroDivisionError): @@ -177,8 +193,9 @@ async def test_gather_partials_async_return_exceptions(self, cs_async): """ Test async version of CrossSync.gather_partials() with return_exceptions=True """ + async def coro(i): - return i + 1 if i != 3 else 1/0 + return i + 1 if i != 3 else 1 / 0 partials = [functools.partial(coro, i) for i in range(5)] results = await cs_async.gather_partials(partials, return_exceptions=True) @@ -194,13 +211,16 @@ async def test_gather_partials_async_uses_asyncio_gather(self, cs_async): """ CrossSync.gather_partials() should use asyncio.gather() internally """ + async def coro(i): return i + 1 - return_exceptions=object() + return_exceptions = object() partials = [functools.partial(coro, i) for i in range(5)] with mock.patch.object(asyncio, "gather", mock.AsyncMock()) as gather: - await cs_async.gather_partials(partials, return_exceptions=return_exceptions) + await cs_async.gather_partials( + partials, return_exceptions=return_exceptions + ) gather.assert_called_once() found_args, found_kwargs = gather.call_args assert found_kwargs["return_exceptions"] == 
return_exceptions @@ -249,7 +269,6 @@ async def test_event_wait_async(self, cs_async, break_early): await cs_async.event_wait(event, async_break_early=break_early) event.wait.assert_called_once_with() - @pytest.mark.asyncio async def test_event_wait_async_with_timeout(self, cs_async): """ @@ -308,7 +327,7 @@ def test_create_task(self, cs_sync): Test creating Future using create_task() """ executor = concurrent.futures.ThreadPoolExecutor() - fn = lambda x, y: x + y + fn = lambda x, y: x + y # noqa: E731 result = cs_sync.create_task(fn, 1, y=4, sync_executor=executor) assert isinstance(result, cs_sync.Task) assert result.result() == 5 @@ -327,7 +346,6 @@ def test_create_task_passthrough(self, cs_sync): assert executor.submit.call_count == 1 assert executor.submit.call_args == ((fn, *args), kwargs) - def test_create_task_no_executor(self, cs_sync): """ if no executor is provided, raise an exception @@ -341,8 +359,10 @@ async def test_create_task_async(self, cs_async): """ Test creating Future using create_task() """ + async def coro_fn(x, y): return x + y + result = cs_async.create_task(coro_fn, 1, y=4) assert isinstance(result, asyncio.Task) assert await result == 5 @@ -358,6 +378,7 @@ async def test_create_task_async_passthrough(self, cs_async): kwargs = {"a": 1, "b": 2} with mock.patch.object(asyncio, "create_task", mock.Mock()) as create_task: result = cs_async.create_task(coro_fn, *args, **kwargs) + assert isinstance(result, asyncio.Task) create_task.assert_called_once() create_task.assert_called_once_with(coro_fn.return_value) coro_fn.assert_called_once_with(*args, **kwargs) @@ -367,8 +388,10 @@ async def test_create_task_async_with_name(self, cs_async): """ Test creating a task with a name """ + async def coro_fn(): return None + name = "test-name-456" result = cs_async.create_task(coro_fn, task_name=name) assert isinstance(result, asyncio.Task) @@ -445,7 +468,9 @@ def test_add_mapping_decorator(self, cs_sync, cs_async): add_mapping_decorator should allow wrapping classes with add_mapping() """ for cls in [cs_sync, cs_async]: + @cls.add_mapping_decorator("decorated") class Decorated: pass + assert cls.decorated == Decorated diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_sync/test_cross_sync_decorators.py index d31a1d0ed..988d8d113 100644 --- a/tests/unit/data/_sync/test_cross_sync_decorators.py +++ b/tests/unit/data/_sync/test_cross_sync_decorators.py @@ -17,19 +17,27 @@ import ast from unittest import mock from google.cloud.bigtable.data._sync.cross_sync.cross_sync import CrossSync -from google.cloud.bigtable.data._sync.cross_sync._decorators import ExportSync, Convert, DropMethod, Pytest, PytestFixture +from google.cloud.bigtable.data._sync.cross_sync._decorators import ( + ExportSync, + Convert, + DropMethod, + Pytest, + PytestFixture, +) @pytest.fixture def globals_mock(): mock_transform = mock.Mock() mock_transform().visit = lambda x: x - global_dict = {k: mock_transform for k in ["RmAioFunctions", "SymbolReplacer", "CrossSyncMethodDecoratorHandler"]} + global_dict = { + k: mock_transform + for k in ["RmAioFunctions", "SymbolReplacer", "CrossSyncMethodDecoratorHandler"] + } return global_dict class TestExportSyncDecorator: - def _get_class(self): return ExportSync @@ -44,7 +52,7 @@ def test_ctor_defaults(self): instance = self._get_class()(path) assert instance.path is path assert instance.replace_symbols is None - assert instance.mypy_ignore is () + assert instance.mypy_ignore == () assert instance.include_file_imports is True assert 
instance.add_mapping_for_name is None assert instance.async_docstring_format_vars == {} @@ -64,7 +72,7 @@ def test_ctor(self): docstring_format_vars=docstring_format_vars, mypy_ignore=mypy_ignore, include_file_imports=include_file_imports, - add_mapping_for_name=add_mapping_for_name + add_mapping_for_name=add_mapping_for_name, ) assert instance.path is path assert instance.replace_symbols is replace_symbols @@ -97,22 +105,27 @@ def test_class_decorator_adds_mapping(self): assert add_mapping.call_count == 1 add_mapping.assert_called_once_with(name, mock_cls) - @pytest.mark.parametrize("docstring,format_vars,expected", [ - ["test docstring", {}, "test docstring"], - ["{}", {}, "{}"], - ["test_docstring", {"A": (1, 2)}, "test_docstring"], - ["{A}", {"A": (1, 2)}, "1"], - ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "1 3"], - ["hello {world_var}", {"world_var": ("world", "moon")}, "hello world"], - ]) + @pytest.mark.parametrize( + "docstring,format_vars,expected", + [ + ["test docstring", {}, "test docstring"], + ["{}", {}, "{}"], + ["test_docstring", {"A": (1, 2)}, "test_docstring"], + ["{A}", {"A": (1, 2)}, "1"], + ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "1 3"], + ["hello {world_var}", {"world_var": ("world", "moon")}, "hello world"], + ], + ) def test_class_decorator_docstring_update(self, docstring, format_vars, expected): """ If docstring_format_vars is set, should update the docstring of the class being decorated """ + @ExportSync.decorator(path=1, docstring_format_vars=format_vars) class Class: __doc__ = docstring + assert Class.__doc__ == expected # check internal state instance = self._get_class()(path=1, docstring_format_vars=format_vars) @@ -138,16 +151,30 @@ def test_sync_ast_transform_strips_cross_sync_decorators(self, globals_mock): should remove all CrossSync decorators from the class """ decorator = self._get_class()("path") - cross_sync_decorator = ast.Call(func=ast.Attribute(value=ast.Name(id='CrossSync', ctx=ast.Load()), attr='some_decorator', ctx=ast.Load()), args=[], keywords=[]) - other_decorator = ast.Name(id='other_decorator', ctx=ast.Load()) - mock_node = ast.ClassDef(name="AsyncClass", bases=[], keywords=[], body=[], decorator_list=[cross_sync_decorator, other_decorator]) + cross_sync_decorator = ast.Call( + func=ast.Attribute( + value=ast.Name(id="CrossSync", ctx=ast.Load()), + attr="some_decorator", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + other_decorator = ast.Name(id="other_decorator", ctx=ast.Load()) + mock_node = ast.ClassDef( + name="AsyncClass", + bases=[], + keywords=[], + body=[], + decorator_list=[cross_sync_decorator, other_decorator], + ) result = decorator.sync_ast_transform(mock_node, globals_mock) assert isinstance(result, ast.ClassDef) assert len(result.decorator_list) == 1 assert isinstance(result.decorator_list[0], ast.Name) - assert result.decorator_list[0].id == 'other_decorator' + assert result.decorator_list[0].id == "other_decorator" def test_sync_ast_transform_add_mapping(self, globals_mock): """ @@ -162,27 +189,34 @@ def test_sync_ast_transform_add_mapping(self, globals_mock): assert len(result.decorator_list) == 1 assert isinstance(result.decorator_list[0], ast.Call) assert isinstance(result.decorator_list[0].func, ast.Attribute) - assert result.decorator_list[0].func.attr == 'add_mapping_decorator' - assert result.decorator_list[0].args[0].value == 'sync_class' - - @pytest.mark.parametrize("docstring,format_vars,expected", [ - ["test docstring", {}, "test docstring"], - ["{}", {}, "{}"], - ["test_docstring", {"A": (1, 2)}, 
"test_docstring"], - ["{A}", {"A": (1, 2)}, "2"], - ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "2 4"], - ["hello {world_var}", {"world_var": ("world", "moon")}, "hello moon"], - ]) - def test_sync_ast_transform_add_docstring_format(self, docstring, format_vars, expected, globals_mock): + assert result.decorator_list[0].func.attr == "add_mapping_decorator" + assert result.decorator_list[0].args[0].value == "sync_class" + + @pytest.mark.parametrize( + "docstring,format_vars,expected", + [ + ["test docstring", {}, "test docstring"], + ["{}", {}, "{}"], + ["test_docstring", {"A": (1, 2)}, "test_docstring"], + ["{A}", {"A": (1, 2)}, "2"], + ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "2 4"], + ["hello {world_var}", {"world_var": ("world", "moon")}, "hello moon"], + ], + ) + def test_sync_ast_transform_add_docstring_format( + self, docstring, format_vars, expected, globals_mock + ): """ If docstring_format_vars is set, should format the docstring of the new class """ - decorator = self._get_class()("path.to.SyncClass", docstring_format_vars=format_vars) + decorator = self._get_class()( + "path.to.SyncClass", docstring_format_vars=format_vars + ) mock_node = ast.ClassDef( name="AsyncClass", bases=[], keywords=[], - body=[ast.Expr(value=ast.Constant(value=docstring))] + body=[ast.Expr(value=ast.Constant(value=docstring))], ) result = decorator.sync_ast_transform(mock_node, globals_mock) @@ -216,7 +250,9 @@ def test_sync_ast_transform_replace_symbols(self, globals_mock): SymbolReplacer should be called with replace_symbols """ replace_symbols = {"a": "b", "c": "d"} - decorator = self._get_class()("path.to.SyncClass", replace_symbols=replace_symbols) + decorator = self._get_class()( + "path.to.SyncClass", replace_symbols=replace_symbols + ) mock_node = ast.ClassDef(name="AsyncClass", bases=[], keywords=[], body=[]) symbol_transform_mock = mock.Mock() globals_mock = {**globals_mock, "SymbolReplacer": symbol_transform_mock} @@ -232,7 +268,6 @@ def test_sync_ast_transform_replace_symbols(self, globals_mock): class TestConvertDecorator: - def _get_class(self): return Convert @@ -254,7 +289,7 @@ def test_ctor(self): sync_name=sync_name, replace_symbols=replace_symbols, docstring_format_vars=docstring_format_vars, - rm_aio=rm_aio + rm_aio=rm_aio, ) assert instance.sync_name is sync_name assert instance.replace_symbols is replace_symbols @@ -270,22 +305,27 @@ def test_async_decorator_no_docstring(self): wrapped_class = self._get_class().decorator(unwrapped_class) assert unwrapped_class == wrapped_class - @pytest.mark.parametrize("docstring,format_vars,expected", [ - ["test docstring", {}, "test docstring"], - ["{}", {}, "{}"], - ["test_docstring", {"A": (1, 2)}, "test_docstring"], - ["{A}", {"A": (1, 2)}, "1"], - ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "1 3"], - ["hello {world_var}", {"world_var": ("world", "moon")}, "hello world"], - ]) + @pytest.mark.parametrize( + "docstring,format_vars,expected", + [ + ["test docstring", {}, "test docstring"], + ["{}", {}, "{}"], + ["test_docstring", {"A": (1, 2)}, "test_docstring"], + ["{A}", {"A": (1, 2)}, "1"], + ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "1 3"], + ["hello {world_var}", {"world_var": ("world", "moon")}, "hello world"], + ], + ) def test_async_decorator_docstring_update(self, docstring, format_vars, expected): """ If docstring_format_vars is set, should update the docstring of the class being decorated """ + @Convert.decorator(docstring_format_vars=format_vars) class Class: __doc__ = docstring + assert Class.__doc__ == expected # check internal state instance 
= self._get_class()(docstring_format_vars=format_vars) @@ -299,7 +339,9 @@ def test_sync_ast_transform_remove_adef(self): Should convert `async def` methods to `def` methods """ decorator = self._get_class()() - mock_node = ast.AsyncFunctionDef(name="test_method", args=ast.arguments(), body=[]) + mock_node = ast.AsyncFunctionDef( + name="test_method", args=ast.arguments(), body=[] + ) result = decorator.sync_ast_transform(mock_node, {}) @@ -311,7 +353,9 @@ def test_sync_ast_transform_replaces_name(self, globals_mock): Should update the name of the method if sync_name is set """ decorator = self._get_class()(sync_name="new_method_name") - mock_node = ast.AsyncFunctionDef(name="old_method_name", args=ast.arguments(), body=[]) + mock_node = ast.AsyncFunctionDef( + name="old_method_name", args=ast.arguments(), body=[] + ) result = decorator.sync_ast_transform(mock_node, globals_mock) @@ -323,7 +367,9 @@ def test_sync_ast_transform_calls_async_to_sync(self): Should call AsyncToSync if rm_aio is set """ decorator = self._get_class()(rm_aio=True) - mock_node = ast.AsyncFunctionDef(name="test_method", args=ast.arguments(), body=[]) + mock_node = ast.AsyncFunctionDef( + name="test_method", args=ast.arguments(), body=[] + ) async_to_sync_mock = mock.Mock() async_to_sync_mock.visit.return_value = mock_node globals_mock = {"AsyncToSync": lambda: async_to_sync_mock} @@ -337,7 +383,9 @@ def test_sync_ast_transform_replace_symbols(self): """ replace_symbols = {"old_symbol": "new_symbol"} decorator = self._get_class()(replace_symbols=replace_symbols) - mock_node = ast.AsyncFunctionDef(name="test_method", args=ast.arguments(), body=[]) + mock_node = ast.AsyncFunctionDef( + name="test_method", args=ast.arguments(), body=[] + ) symbol_replacer_mock = mock.Mock() globals_mock = {"SymbolReplacer": symbol_replacer_mock} @@ -347,15 +395,20 @@ def test_sync_ast_transform_replace_symbols(self): assert symbol_replacer_mock.call_args[0][0] == replace_symbols assert symbol_replacer_mock(replace_symbols).visit.call_count == 1 - @pytest.mark.parametrize("docstring,format_vars,expected", [ - ["test docstring", {}, "test docstring"], - ["{}", {}, "{}"], - ["test_docstring", {"A": (1, 2)}, "test_docstring"], - ["{A}", {"A": (1, 2)}, "2"], - ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "2 4"], - ["hello {world_var}", {"world_var": ("world", "moon")}, "hello moon"], - ]) - def test_sync_ast_transform_add_docstring_format(self, docstring, format_vars, expected): + @pytest.mark.parametrize( + "docstring,format_vars,expected", + [ + ["test docstring", {}, "test docstring"], + ["{}", {}, "{}"], + ["test_docstring", {"A": (1, 2)}, "test_docstring"], + ["{A}", {"A": (1, 2)}, "2"], + ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "2 4"], + ["hello {world_var}", {"world_var": ("world", "moon")}, "hello moon"], + ], + ) + def test_sync_ast_transform_add_docstring_format( + self, docstring, format_vars, expected + ): """ If docstring_format_vars is set, should format the docstring of the new method """ @@ -363,7 +416,7 @@ def test_sync_ast_transform_add_docstring_format(self, docstring, format_vars, e mock_node = ast.AsyncFunctionDef( name="test_method", args=ast.arguments(), - body=[ast.Expr(value=ast.Constant(value=docstring))] + body=[ast.Expr(value=ast.Constant(value=docstring))], ) result = decorator.sync_ast_transform(mock_node, {}) @@ -375,7 +428,6 @@ def test_sync_ast_transform_add_docstring_format(self, docstring, format_vars, e class TestDropMethodDecorator: - def _get_class(self): return DropMethod @@ -383,7 +435,7 @@ def 
test_decorator_functionality(self): """ applying the decorator should be a no-op """ - unwrapped = lambda x: x + unwrapped = lambda x: x # noqa: E731 wrapped = self._get_class().decorator(unwrapped) assert unwrapped == wrapped assert unwrapped(1) == wrapped(1) @@ -394,7 +446,9 @@ def test_sync_ast_transform(self): Should return None for any input method """ decorator = self._get_class()() - mock_node = ast.AsyncFunctionDef(name="test_method", args=ast.arguments(), body=[]) + mock_node = ast.AsyncFunctionDef( + name="test_method", args=ast.arguments(), body=[] + ) result = decorator.sync_ast_transform(mock_node, {}) @@ -402,7 +456,6 @@ def test_sync_ast_transform(self): class TestPytestDecorator: - def _get_class(self): return Pytest @@ -447,7 +500,6 @@ def test_sync_ast_transform_rm_aio(self): class TestPytestFixtureDecorator: - def _get_class(self): return PytestFixture @@ -456,6 +508,7 @@ def test_decorator_functionality(self): Should wrap the class with pytest_asyncio.fixture """ with mock.patch.object(pytest_asyncio, "fixture") as fixture: + @PytestFixture.decorator(1, 2, scope="function", params=[3, 4]) def fn(): pass @@ -470,7 +523,9 @@ def test_sync_ast_transform(self): """ decorator = self._get_class()(1, 2, scope="function") - mock_node = ast.AsyncFunctionDef(name="test_method", args=ast.arguments(), body=[]) + mock_node = ast.AsyncFunctionDef( + name="test_method", args=ast.arguments(), body=[] + ) result = decorator.sync_ast_transform(mock_node, {}) From 0c91eb7f450667479e241900361e647131869b7a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 6 Sep 2024 15:08:40 -0700 Subject: [PATCH 251/360] fixed mypy --- .../bigtable/data/_sync/cross_sync/_decorators.py | 14 ++++++++------ tests/unit/data/_sync/test_cross_sync.py | 1 - 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index 9c83079fb..860a7cfb3 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -275,9 +275,10 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): # update docstring if specified if self.sync_docstring_format_vars: docstring = ast.get_docstring(wrapped_node) - wrapped_node.body[0].value.s = docstring.format( - **self.sync_docstring_format_vars - ) + if docstring: + wrapped_node.body[0].value.s = docstring.format( + **self.sync_docstring_format_vars + ) return wrapped_node @@ -345,9 +346,10 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): # update docstring if specified if self.sync_docstring_format_vars: docstring = ast.get_docstring(wrapped_node) - wrapped_node.body[0].value.s = docstring.format( - **self.sync_docstring_format_vars - ) + if docstring: + wrapped_node.body[0].value.s = docstring.format( + **self.sync_docstring_format_vars + ) return wrapped_node def async_decorator(self): diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py index 8db4670e5..fd0295f78 100644 --- a/tests/unit/data/_sync/test_cross_sync.py +++ b/tests/unit/data/_sync/test_cross_sync.py @@ -378,7 +378,6 @@ async def test_create_task_async_passthrough(self, cs_async): kwargs = {"a": 1, "b": 2} with mock.patch.object(asyncio, "create_task", mock.Mock()) as create_task: result = cs_async.create_task(coro_fn, *args, **kwargs) - assert isinstance(result, asyncio.Task) create_task.assert_called_once() 
create_task.assert_called_once_with(coro_fn.return_value) coro_fn.assert_called_once_with(*args, **kwargs) From 5484f5e9300d87f27e94171353d0b36a38fe2fff Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 6 Sep 2024 15:10:37 -0700 Subject: [PATCH 252/360] changed system_emulated version --- noxfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index 22a3a4ed0..2c21f2425 100644 --- a/noxfile.py +++ b/noxfile.py @@ -258,7 +258,7 @@ def install_systemtest_dependencies(session, *constraints): session.install("-e", ".", *constraints) -@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) +@nox.session(python=3.8) def system_emulated(session): import subprocess import signal From e7881da31c545af6286ed4bcb3e19de4fabd6cca Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 6 Sep 2024 15:14:58 -0700 Subject: [PATCH 253/360] fixed system test version format --- noxfile.py | 2 +- tests/unit/data/_sync/test_cross_sync.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/noxfile.py b/noxfile.py index 2c21f2425..1e153efe2 100644 --- a/noxfile.py +++ b/noxfile.py @@ -258,7 +258,7 @@ def install_systemtest_dependencies(session, *constraints): session.install("-e", ".", *constraints) -@nox.session(python=3.8) +@nox.session(python="3.8") def system_emulated(session): import subprocess import signal diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py index fd0295f78..4394889d5 100644 --- a/tests/unit/data/_sync/test_cross_sync.py +++ b/tests/unit/data/_sync/test_cross_sync.py @@ -377,7 +377,7 @@ async def test_create_task_async_passthrough(self, cs_async): args = [1, 2, 3] kwargs = {"a": 1, "b": 2} with mock.patch.object(asyncio, "create_task", mock.Mock()) as create_task: - result = cs_async.create_task(coro_fn, *args, **kwargs) + cs_async.create_task(coro_fn, *args, **kwargs) create_task.assert_called_once() create_task.assert_called_once_with(coro_fn.return_value) coro_fn.assert_called_once_with(*args, **kwargs) From bc67b30ffb8db85f6e14733dacfb14f2fb2575a8 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 6 Sep 2024 15:38:08 -0700 Subject: [PATCH 254/360] fixed unit-3.7 issues --- .../data/_sync/cross_sync/_decorators.py | 8 +++--- .../data/_sync/cross_sync/cross_sync.py | 6 +---- tests/unit/data/_async/test_client.py | 3 +-- tests/unit/data/_sync/test_cross_sync.py | 27 ++++++++++++------- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index 860a7cfb3..b1456951b 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -276,8 +276,8 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): if self.sync_docstring_format_vars: docstring = ast.get_docstring(wrapped_node) if docstring: - wrapped_node.body[0].value.s = docstring.format( - **self.sync_docstring_format_vars + wrapped_node.body[0].value = ast.Constant( + value=docstring.format(**self.sync_docstring_format_vars) ) return wrapped_node @@ -347,8 +347,8 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): if self.sync_docstring_format_vars: docstring = ast.get_docstring(wrapped_node) if docstring: - wrapped_node.body[0].value.s = docstring.format( - **self.sync_docstring_format_vars + wrapped_node.body[0].value = ast.Constant( + value=docstring.format(**self.sync_docstring_format_vars) ) 
return wrapped_node diff --git a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py index 7fb2a794b..e85382ddd 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py @@ -244,11 +244,7 @@ class _Sync_Impl(metaclass=MappingMeta): @classmethod def Mock(cls, *args, **kwargs): - # try/except added for compatibility with python < 3.8 - try: - from unittest.mock import Mock - except ImportError: # pragma: NO COVER - from mock import Mock # type: ignore + from unittest.mock import Mock return Mock(*args, **kwargs) @staticmethod diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 6c49ca0da..21f318be3 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -33,11 +33,10 @@ from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule # try/except added for compatibility with python < 3.8 +from unittest import mock try: - from unittest import mock from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER - import mock # type: ignore from mock import AsyncMock # type: ignore VENEER_HEADER_REGEX = re.compile( diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py index 4394889d5..6deb5f137 100644 --- a/tests/unit/data/_sync/test_cross_sync.py +++ b/tests/unit/data/_sync/test_cross_sync.py @@ -20,10 +20,17 @@ import time import queue import functools +import sys from google import api_core from google.cloud.bigtable.data._sync.cross_sync.cross_sync import CrossSync, T -from unittest import mock +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # type: ignore +except ImportError: # pragma: NO COVER + import mock # type: ignore + from mock import AsyncMock # type: ignore class TestCrossSync: async def async_iter(self, in_list): @@ -86,8 +93,9 @@ async def test_Mock(self, cs_sync, cs_async): """ Test Mock class in both sync and async versions """ - assert isinstance(cs_async.Mock(), mock.AsyncMock) - assert isinstance(cs_sync.Mock(), mock.Mock) + import unittest.mock + assert isinstance(cs_async.Mock(), AsyncMock) + assert isinstance(cs_sync.Mock(), unittest.mock.Mock) # test with return value assert await cs_async.Mock(return_value=1)() == 1 assert cs_sync.Mock(return_value=1)() == 1 @@ -217,7 +225,7 @@ async def coro(i): return_exceptions = object() partials = [functools.partial(coro, i) for i in range(5)] - with mock.patch.object(asyncio, "gather", mock.AsyncMock()) as gather: + with mock.patch.object(asyncio, "gather", AsyncMock()) as gather: await cs_async.gather_partials( partials, return_exceptions=return_exceptions ) @@ -265,7 +273,7 @@ async def test_event_wait_async(self, cs_async, break_early): """ With no timeout, call event.wait() with no arguments """ - event = mock.AsyncMock() + event = AsyncMock() await cs_async.event_wait(event, async_break_early=break_early) event.wait.assert_called_once_with() @@ -278,7 +286,7 @@ async def test_event_wait_async_with_timeout(self, cs_async): event = mock.Mock() event.wait.return_value = object() timeout = object() - with mock.patch.object(asyncio, "wait_for", mock.AsyncMock()) as wait_for: + with mock.patch.object(asyncio, "wait_for", AsyncMock()) as wait_for: await cs_async.event_wait(event, timeout=timeout) assert wait_for.await_count == 1 assert 
wait_for.call_count == 1 @@ -302,7 +310,7 @@ async def test_event_wait_async_already_set(self, cs_async, break_early): """ if event is already set, return immediately """ - event = mock.AsyncMock() + event = AsyncMock() event.is_set = lambda: True start_time = time.monotonic() await cs_async.event_wait(event, async_break_early=break_early) @@ -318,7 +326,7 @@ async def test_event_wait_no_break_early(self, cs_async): event = mock.Mock() event.is_set.return_value = False timeout = object() - with mock.patch.object(asyncio, "sleep", mock.AsyncMock()) as sleep: + with mock.patch.object(asyncio, "sleep", AsyncMock()) as sleep: await cs_async.event_wait(event, timeout=timeout, async_break_early=False) sleep.assert_called_once_with(timeout) @@ -382,6 +390,7 @@ async def test_create_task_async_passthrough(self, cs_async): create_task.assert_called_once_with(coro_fn.return_value) coro_fn.assert_called_once_with(*args, **kwargs) + @pytest.mark.skipif(sys.version_info < (3, 8), reason="Task names require python 3.8") @pytest.mark.asyncio async def test_create_task_async_with_name(self, cs_async): """ @@ -407,7 +416,7 @@ async def test_yield_to_event_loop_async(self, cs_async): """ should call await asyncio.sleep(0) """ - with mock.patch.object(asyncio, "sleep", mock.AsyncMock()) as sleep: + with mock.patch.object(asyncio, "sleep", AsyncMock()) as sleep: await cs_async.yield_to_event_loop() sleep.assert_called_once_with(0) From 11ab1d032d44d7a2eea12a994462833fb7fefb7b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 6 Sep 2024 15:40:42 -0700 Subject: [PATCH 255/360] fixed lint --- google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py | 1 + tests/unit/data/_async/test_client.py | 1 + tests/unit/data/_sync/test_cross_sync.py | 6 +++++- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py index e85382ddd..24776aabd 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py @@ -245,6 +245,7 @@ class _Sync_Impl(metaclass=MappingMeta): @classmethod def Mock(cls, *args, **kwargs): from unittest.mock import Mock + return Mock(*args, **kwargs) @staticmethod diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 21f318be3..88bc6ca7b 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -34,6 +34,7 @@ # try/except added for compatibility with python < 3.8 from unittest import mock + try: from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py index 6deb5f137..9c3022f6c 100644 --- a/tests/unit/data/_sync/test_cross_sync.py +++ b/tests/unit/data/_sync/test_cross_sync.py @@ -32,6 +32,7 @@ import mock # type: ignore from mock import AsyncMock # type: ignore + class TestCrossSync: async def async_iter(self, in_list): for i in in_list: @@ -94,6 +95,7 @@ async def test_Mock(self, cs_sync, cs_async): Test Mock class in both sync and async versions """ import unittest.mock + assert isinstance(cs_async.Mock(), AsyncMock) assert isinstance(cs_sync.Mock(), unittest.mock.Mock) # test with return value @@ -390,7 +392,9 @@ async def test_create_task_async_passthrough(self, cs_async): create_task.assert_called_once_with(coro_fn.return_value) coro_fn.assert_called_once_with(*args, **kwargs) - 
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Task names require python 3.8") + @pytest.mark.skipif( + sys.version_info < (3, 8), reason="Task names require python 3.8" + ) @pytest.mark.asyncio async def test_create_task_async_with_name(self, cs_async): """ From bb9c160fc194fa21cc64c0a9e07e7edc1e376fc2 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 9 Sep 2024 16:27:50 -0700 Subject: [PATCH 256/360] added mutations batcher tests --- tests/system/data/test_system_async.py | 25 +++++++++++++++++++ .../data/_async/test_mutations_batcher.py | 21 ++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 28a89a8e2..e8edeef0f 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -482,6 +482,31 @@ async def test_mutations_batcher_no_flush(self, client, table, temp_rows): assert (await self._retrieve_cell_value(table, row_key)) == start_value assert (await self._retrieve_cell_value(table, row_key2)) == start_value + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + @CrossSync.pytest + async def test_mutations_batcher_large_batch(self, client, table, temp_rows): + """ + test batcher with large batch of mutations + """ + from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell + add_mutation = SetCell(family=TEST_FAMILY, qualifier=b"test-qualifier", new_value=b"a") + row_mutations = [] + for i in range(50_000): + row_key = uuid.uuid4().hex.encode() + row_mutations.append(RowMutationEntry(row_key, [add_mutation])) + # append row key for eventual deletion + temp_rows.rows.append(row_key) + + async with table.mutations_batcher() as batcher: + for mutation in row_mutations: + await batcher.append(mutation) + # ensure cell is updated + assert len(batcher._staged_entries) == 0 + @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("table") @pytest.mark.parametrize( diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 2c61d005a..6e5949575 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -1217,3 +1217,24 @@ async def test_customizable_retryable_errors( retry_call_args = retry_fn_mock.call_args_list[0].args # output of if_exception_type should be sent in to retry constructor assert retry_call_args[1] is expected_predicate + + @CrossSync.pytest + async def test_large_batch_write(self): + """ + Test that a large batch of mutations can be written + """ + import math + num_mutations = 10_000 + flush_limit = 1000 + mutations = [self._make_mutation(count=1, size=1)] * num_mutations + async with self._make_one(flush_limit_mutation_count=flush_limit) as instance: + operation_mock = mock.Mock() + rpc_call_mock = CrossSync.Mock() + operation_mock().start = rpc_call_mock + CrossSync._MutateRowsOperation = operation_mock + for m in mutations: + await instance.append(m) + expected_calls = math.ceil(num_mutations / flush_limit) + assert rpc_call_mock.call_count == expected_calls + assert instance._entries_processed_since_last_raise == num_mutations + assert len(instance._staged_entries) == 0 From 06e04253cb7c303a496b8c78707144b40f574ca7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 9 Sep 2024 16:28:23 -0700 Subject: [PATCH 257/360] batcher uses multuple sync executors --- 
.../bigtable/data/_async/mutations_batcher.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index d075768df..76960254e 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -247,13 +247,20 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - self._sync_executor = ( + # used by sync class to run mutate_rows operations + self._sync_rpc_executor = ( concurrent.futures.ThreadPoolExecutor(max_workers=8) if not CrossSync.is_async else None ) + # used by sync class to manage flush_internal tasks + self._sync_flush_executor = ( + concurrent.futures.ThreadPoolExecutor(max_workers=1) + if not CrossSync.is_async + else None + ) self._flush_timer = CrossSync.create_task( - self._timer_routine, flush_interval, sync_executor=self._sync_executor + self._timer_routine, flush_interval, sync_executor=self._sync_flush_executor ) self._flush_jobs: set[CrossSync.Future[None]] = set() # MutationExceptionGroup reports number of successful entries along with failures @@ -333,7 +340,7 @@ def _schedule_flush(self) -> CrossSync.Future[None] | None: entries, self._staged_entries = self._staged_entries, [] self._staged_count, self._staged_bytes = 0, 0 new_task = CrossSync.create_task( - self._flush_internal, entries, sync_executor=self._sync_executor + self._flush_internal, entries, sync_executor=self._sync_flush_executor ) if not new_task.done(): self._flush_jobs.add(new_task) @@ -355,7 +362,7 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): self._flow_control.add_to_flow(new_entries) ): batch_task = CrossSync.create_task( - self._execute_mutate_rows, batch, sync_executor=self._sync_executor + self._execute_mutate_rows, batch, sync_executor=self._sync_rpc_executor ) in_process_requests.append(batch_task) # wait for all inflight requests to complete @@ -477,11 +484,14 @@ async def close(self): self._closed.set() self._flush_timer.cancel() self._schedule_flush() + # shut down executors + if self._sync_rpc_executor: + with self._sync_rpc_executor: + self._sync_rpc_executor.shutdown(wait=True) + if self._sync_flush_executor: + with self._sync_flush_executor: + self._sync_flush_executor.shutdown(wait=True) CrossSync.rm_aio(await CrossSync.wait([*self._flush_jobs, self._flush_timer])) - # shut down executor - if self._sync_executor: - with self._sync_executor: - self._sync_executor.shutdown(wait=True) atexit.unregister(self._on_exit) # raise unreported exceptions self._raise_exceptions() From 13abfd452296cf1a089088ec8ccc27127834e4de Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 9 Sep 2024 16:52:21 -0700 Subject: [PATCH 258/360] implemented wait manually --- .../data/_sync/cross_sync/cross_sync.py | 27 +++++- tests/unit/data/_sync/test_cross_sync.py | 93 ++++++++++++++++++- 2 files changed, 117 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py index 24776aabd..dceff5a62 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py @@ -80,7 +80,6 @@ class CrossSync(metaclass=MappingMeta): # provide aliases for common async functions and types sleep = asyncio.sleep - wait = asyncio.wait retry_target = retries.retry_target_async retry_target_stream = 
retries.retry_target_stream_async Retry = retries.AsyncRetry @@ -147,6 +146,20 @@ async def gather_partials( *awaitable_list, return_exceptions=return_exceptions ) + @staticmethod + async def wait( + futures: Sequence[CrossSync.Future[T]], timeout: float | None = None + ) -> tuple[set[CrossSync.Future[T]], set[CrossSync.Future[T]]]: + """ + abstraction over asyncio.wait + + Return: + - a tuple of (done, pending) sets of futures + """ + if not futures: + return set(), set() + return await asyncio.wait(futures, timeout=timeout) + @staticmethod async def event_wait( event: CrossSync.Event, @@ -224,7 +237,6 @@ class _Sync_Impl(metaclass=MappingMeta): is_async = False sleep = time.sleep - wait = concurrent.futures.wait next = next retry_target = retries.retry_target retry_target_stream = retries.retry_target_stream @@ -279,6 +291,17 @@ def gather_partials( results_list.append(future.result()) return results_list + @staticmethod + def wait( + futures: Sequence[CrossSync._Sync_Impl.Future[T]], + timeout: float | None = None, + ) -> tuple[ + set[CrossSync._Sync_Impl.Future[T]], set[CrossSync._Sync_Impl.Future[T]] + ]: + if not futures: + return set(), set() + return concurrent.futures.wait(futures, timeout=timeout) + @staticmethod def create_task( fn: Callable[..., T], diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py index 9c3022f6c..903207694 100644 --- a/tests/unit/data/_sync/test_cross_sync.py +++ b/tests/unit/data/_sync/test_cross_sync.py @@ -51,7 +51,6 @@ def cs_async(self): [ ("is_async", True, False), ("sleep", asyncio.sleep, time.sleep), - ("wait", asyncio.wait, concurrent.futures.wait), ( "retry_target", api_core.retry.retry_target_async, @@ -237,6 +236,98 @@ async def coro(i): for coro in found_args: await coro + def test_wait(self, cs_sync): + """ + Test sync version of CrossSync.wait() + + If future is complete, it should be in the first (complete) set + """ + future = concurrent.futures.Future() + future.set_result(1) + s1, s2 = cs_sync.wait([future]) + assert s1 == {future} + assert s2 == set() + + def test_wait_timeout(self, cs_sync): + """ + If timeout occurs, future should be in the second (incomplete) set + """ + future = concurrent.futures.Future() + timeout = 0.1 + start_time = time.monotonic() + s1, s2 = cs_sync.wait([future], timeout) + end_time = time.monotonic() + assert abs((end_time - start_time) - timeout) < 0.01 + assert s1 == set() + assert s2 == {future} + + def test_wait_passthrough(self, cs_sync): + """ + sync version of CrossSync.wait() should pass through to concurrent.futures.wait() + """ + future = object() + timeout = object() + with mock.patch.object(concurrent.futures, "wait", mock.Mock()) as wait: + result = cs_sync.wait([future], timeout) + assert wait.call_count == 1 + assert wait.call_args == (([future],), {"timeout": timeout}) + assert result == wait.return_value + + def test_wait_empty_input(self, cs_sync): + """ + If no futures are provided, return empty sets + """ + s1, s2 = cs_sync.wait([]) + assert s1 == set() + assert s2 == set() + + @pytest.mark.asyncio + async def test_wait_async(self, cs_async): + """ + Test async version of CrossSync.wait() + """ + future = asyncio.Future() + future.set_result(1) + s1, s2 = await cs_async.wait([future]) + assert s1 == {future} + assert s2 == set() + + @pytest.mark.asyncio + async def test_wait_async_timeout(self, cs_async): + """ + If timeout occurs, future should be in the second (incomplete) set + """ + future = asyncio.Future() + timeout = 0.1 + start_time = 
time.monotonic() + s1, s2 = await cs_async.wait([future], timeout) + end_time = time.monotonic() + assert abs((end_time - start_time) - timeout) < 0.01 + assert s1 == set() + assert s2 == {future} + + @pytest.mark.asyncio + async def test_wait_async_passthrough(self, cs_async): + """ + async version of CrossSync.wait() should pass through to asyncio.wait() + """ + future = object() + timeout = object() + with mock.patch.object(asyncio, "wait", AsyncMock()) as wait: + result = await cs_async.wait([future], timeout) + assert wait.call_count == 1 + assert wait.call_args == (([future],), {"timeout": timeout}) + assert result == wait.return_value + + @pytest.mark.asyncio + async def test_wait_async_empty_input(self, cs_async): + """ + If no futures are provided, return empty sets + """ + s1, s2 = await cs_async.wait([]) + assert s1 == set() + assert s2 == set() + def test_event_wait_passthrough(self, cs_sync): """ Test sync version of CrossSync.event_wait() From 6ceb9133622887c62479307b6ca396890f4e5951 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 9 Sep 2024 16:53:37 -0700 Subject: [PATCH 259/360] regenerated files --- .../bigtable/data/_sync/mutations_batcher.py | 22 +++++++++++++------ tests/system/data/test_system.py | 22 +++++++++++++++++++ .../unit/data/_sync/test_mutations_batcher.py | 19 ++++++++++++++++ 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index dfd889d0e..8d15733d3 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -217,13 +217,18 @@ def __init__( if flush_limit_mutation_count is not None else float("inf") ) - self._sync_executor = ( + self._sync_rpc_executor = ( concurrent.futures.ThreadPoolExecutor(max_workers=8) if not CrossSync._Sync_Impl.is_async else None ) + self._sync_flush_executor = ( + concurrent.futures.ThreadPoolExecutor(max_workers=1) + if not CrossSync._Sync_Impl.is_async + else None + ) self._flush_timer = CrossSync._Sync_Impl.create_task( - self._timer_routine, flush_interval, sync_executor=self._sync_executor + self._timer_routine, flush_interval, sync_executor=self._sync_flush_executor ) self._flush_jobs: set[CrossSync._Sync_Impl.Future[None]] = set() self._entries_processed_since_last_raise: int = 0 @@ -286,7 +291,7 @@ def _schedule_flush(self) -> CrossSync._Sync_Impl.Future[None] | None: entries, self._staged_entries = (self._staged_entries, []) self._staged_count, self._staged_bytes = (0, 0) new_task = CrossSync._Sync_Impl.create_task( - self._flush_internal, entries, sync_executor=self._sync_executor + self._flush_internal, entries, sync_executor=self._sync_flush_executor ) if not new_task.done(): self._flush_jobs.add(new_task) @@ -304,7 +309,7 @@ def _flush_internal(self, new_entries: list[RowMutationEntry]): ] = [] for batch in self._flow_control.add_to_flow(new_entries): batch_task = CrossSync._Sync_Impl.create_task( - self._execute_mutate_rows, batch, sync_executor=self._sync_executor + self._execute_mutate_rows, batch, sync_executor=self._sync_rpc_executor ) in_process_requests.append(batch_task) found_exceptions = self._wait_for_batch_results(*in_process_requests) @@ -402,10 +407,13 @@ def close(self): self._closed.set() self._flush_timer.cancel() self._schedule_flush() + if self._sync_rpc_executor: + with self._sync_rpc_executor: + self._sync_rpc_executor.shutdown(wait=True) + if self._sync_flush_executor: + with 
self._sync_flush_executor: + self._sync_flush_executor.shutdown(wait=True) CrossSync._Sync_Impl.wait([*self._flush_jobs, self._flush_timer]) - if self._sync_executor: - with self._sync_executor: - self._sync_executor.shutdown(wait=True) atexit.unregister(self._on_exit) self._raise_exceptions() diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py index 32e24463b..30dc052c9 100644 --- a/tests/system/data/test_system.py +++ b/tests/system/data/test_system.py @@ -373,6 +373,28 @@ def test_mutations_batcher_no_flush(self, client, table, temp_rows): assert self._retrieve_cell_value(table, row_key) == start_value assert self._retrieve_cell_value(table, row_key2) == start_value + @pytest.mark.usefixtures("client") + @pytest.mark.usefixtures("table") + @CrossSync._Sync_Impl.Retry( + predicate=retry.if_exception_type(ClientError), initial=1, maximum=5 + ) + def test_mutations_batcher_large_batch(self, client, table, temp_rows): + """test batcher with large batch of mutations""" + from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell + + add_mutation = SetCell( + family=TEST_FAMILY, qualifier=b"test-qualifier", new_value=b"a" + ) + row_mutations = [] + for i in range(50000): + row_key = uuid.uuid4().hex.encode() + row_mutations.append(RowMutationEntry(row_key, [add_mutation])) + temp_rows.rows.append(row_key) + with table.mutations_batcher() as batcher: + for mutation in row_mutations: + batcher.append(mutation) + assert len(batcher._staged_entries) == 0 + @pytest.mark.usefixtures("client") @pytest.mark.usefixtures("table") @pytest.mark.parametrize( diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py index fe3792293..98d23da85 100644 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -1081,3 +1081,22 @@ def test_customizable_retryable_errors(self, input_retryables, expected_retryabl ) retry_call_args = retry_fn_mock.call_args_list[0].args assert retry_call_args[1] is expected_predicate + + def test_large_batch_write(self): + """Test that a large batch of mutations can be written""" + import math + + num_mutations = 10000 + flush_limit = 1000 + mutations = [self._make_mutation(count=1, size=1)] * num_mutations + with self._make_one(flush_limit_mutation_count=flush_limit) as instance: + operation_mock = mock.Mock() + rpc_call_mock = CrossSync._Sync_Impl.Mock() + operation_mock().start = rpc_call_mock + CrossSync._Sync_Impl._MutateRowsOperation = operation_mock + for m in mutations: + instance.append(m) + expected_calls = math.ceil(num_mutations / flush_limit) + assert rpc_call_mock.call_count == expected_calls + assert instance._entries_processed_since_last_raise == num_mutations + assert len(instance._staged_entries) == 0 From 1318ab6645788d318018552cf42e3d2ea5b54aeb Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 9 Sep 2024 17:12:33 -0700 Subject: [PATCH 260/360] fixed order of executor close --- google/cloud/bigtable/data/_async/mutations_batcher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 76960254e..443730d08 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -485,12 +485,12 @@ async def close(self): self._flush_timer.cancel() self._schedule_flush() # shut down executors - if self._sync_rpc_executor: 
- with self._sync_rpc_executor: - self._sync_rpc_executor.shutdown(wait=True) if self._sync_flush_executor: with self._sync_flush_executor: self._sync_flush_executor.shutdown(wait=True) + if self._sync_rpc_executor: + with self._sync_rpc_executor: + self._sync_rpc_executor.shutdown(wait=True) CrossSync.rm_aio(await CrossSync.wait([*self._flush_jobs, self._flush_timer])) atexit.unregister(self._on_exit) # raise unreported exceptions From fd4b779f9ed9391a9644f7643acaf6c3a38ebc79 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 9 Sep 2024 17:13:22 -0700 Subject: [PATCH 261/360] updated generated code --- google/cloud/bigtable/data/_sync/mutations_batcher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 8d15733d3..9b9c750da 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -407,12 +407,12 @@ def close(self): self._closed.set() self._flush_timer.cancel() self._schedule_flush() - if self._sync_rpc_executor: - with self._sync_rpc_executor: - self._sync_rpc_executor.shutdown(wait=True) if self._sync_flush_executor: with self._sync_flush_executor: self._sync_flush_executor.shutdown(wait=True) + if self._sync_rpc_executor: + with self._sync_rpc_executor: + self._sync_rpc_executor.shutdown(wait=True) CrossSync._Sync_Impl.wait([*self._flush_jobs, self._flush_timer]) atexit.unregister(self._on_exit) self._raise_exceptions() From 8d15f48735ebab76b1c08a8b336a5f99cc058d1b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 9 Sep 2024 17:18:04 -0700 Subject: [PATCH 262/360] changed num workers --- google/cloud/bigtable/data/_async/mutations_batcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 443730d08..1f5eba221 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -255,7 +255,7 @@ def __init__( ) # used by sync class to manage flush_internal tasks self._sync_flush_executor = ( - concurrent.futures.ThreadPoolExecutor(max_workers=1) + concurrent.futures.ThreadPoolExecutor(max_workers=4) if not CrossSync.is_async else None ) From 8bdddf488767b9ad3620baab4121a7e1f0ee2704 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 10 Sep 2024 14:28:00 -0700 Subject: [PATCH 263/360] added noxfile step to generate sync --- noxfile.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/noxfile.py b/noxfile.py index 1931ac3b1..2f417ca4f 100644 --- a/noxfile.py +++ b/noxfile.py @@ -552,3 +552,11 @@ def prerelease_deps(session, protobuf_implementation): "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, }, ) + +@nox.session(python="3.11") +def generate_sync(session): + """ + Re-generate sync files for the library from CrossSync-annotated async source + """ + session.install("black", "autoflake") + session.run("python", ".cross_sync/generate.py", ".") From bb2a53952a7a864f4ea1858641408d8c41502945 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 10 Sep 2024 14:28:58 -0700 Subject: [PATCH 264/360] added test for generated code --- tests/unit/data/test_sync_up_to_date.py | 51 +++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 tests/unit/data/test_sync_up_to_date.py diff --git a/tests/unit/data/test_sync_up_to_date.py 
b/tests/unit/data/test_sync_up_to_date.py new file mode 100644 index 000000000..cb616f130 --- /dev/null +++ b/tests/unit/data/test_sync_up_to_date.py @@ -0,0 +1,51 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import hashlib +import pytest +from difflib import unified_diff + +# add cross_sync to path +test_dir_name = os.path.dirname(__file__) +repo_root = os.path.join(test_dir_name, "..", "..", "..") +cross_sync_path = os.path.join(repo_root, ".cross_sync") +sys.path.append(cross_sync_path) + +from generate import convert_files_in_dir # noqa: E402 + + +@pytest.mark.parametrize( + "artifact", convert_files_in_dir(repo_root), ids=lambda a: a.file_path +) +@pytest.mark.skipif(sys.version_info < (3, 11), reason="generation uses python3.11") +def test_sync_up_to_date(artifact): + """ + Generate a fresh copy of each cross_sync file, and compare hashes with the existing file. + + If this test fails, run `nox -s generate_sync` to update the sync files. + """ + path = artifact.file_path + new_render = artifact.render() + found_render = open(path).read() + # compare by content + diff = unified_diff( + found_render.splitlines(), new_render.splitlines(), lineterm="" + ) + diff_str = "\n".join(diff) + assert not diff_str, f"Found differences:\n{diff_str}" + # compare by hash + new_hash = hashlib.md5(new_render.encode()).hexdigest() + found_hash = hashlib.md5(found_render.encode()).hexdigest() + assert new_hash == found_hash, f"md5 mismatch for {path}" From 78c1405a41703c1ee8e7b3d09f5bd42432a1632f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 10 Sep 2024 14:57:14 -0700 Subject: [PATCH 265/360] added missing unit test deps --- noxfile.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/noxfile.py b/noxfile.py index 2f417ca4f..ce6ddcde2 100644 --- a/noxfile.py +++ b/noxfile.py @@ -41,6 +41,8 @@ "pytest", "pytest-cov", "pytest-asyncio", + "black", + "autoflake", ] UNIT_TEST_EXTERNAL_DEPENDENCIES: List[str] = [] UNIT_TEST_LOCAL_DEPENDENCIES: List[str] = [] From b4608cb854ebcc763fa02fadabc6e8c609dd39f5 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 10 Sep 2024 15:05:20 -0700 Subject: [PATCH 266/360] run generate_sync from owlbot --- owlbot.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/owlbot.py b/owlbot.py index 0ec4cd61c..c67ae5af8 100644 --- a/owlbot.py +++ b/owlbot.py @@ -171,3 +171,9 @@ def insert(file, before_line, insert_line, after_line, escape=None): INSTALL_LIBRARY_FROM_SOURCE = False""") s.shell.run(["nox", "-s", "blacken"], hide_output=False) + +# ---------------------------------------------------------------------------- +# Run Cross Sync +# ---------------------------------------------------------------------------- + +s.shell.run(["nox", "-s", "generate_sync"]) From 7dc9a2b6606ecf56134203d3174c0fc39b5d2599 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 10 Sep 2024 17:02:38 -0700 Subject: [PATCH 267/360] simplified file processing --- .cross_sync/README.md | 4 +- .cross_sync/generate.py | 71 
+++------- .cross_sync/transformers.py | 121 ++++++------------ .../data/_sync/cross_sync/_decorators.py | 22 +--- ...ync_classes.yaml => cross_sync_files.yaml} | 77 ++++++----- .../system/cross_sync/test_cross_sync_e2e.py | 14 +- .../data/_sync/test_cross_sync_decorators.py | 36 ++---- 7 files changed, 139 insertions(+), 206 deletions(-) rename tests/system/cross_sync/test_cases/{cross_sync_classes.yaml => cross_sync_files.yaml} (70%) diff --git a/.cross_sync/README.md b/.cross_sync/README.md index 563fccb3b..0d43c1bb4 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -63,8 +63,8 @@ CrossSync provides a set of annotations to mark up async classes, to guide the g ### Code Generation Generation can be initiated using `python .cross_sync/generate.py .` -from the root of the project. This will find all classes with the `@CrossSync.export_sync` annotation -in both `/google` and `/tests` directories, and save them to their specified output paths +from the root of the project. This will find all classes with the `__CROSS_SYNC_OUTPUT__ = "path/to/output"` +annotation, and generate a sync version of classes marked with `@CrossSync.export_sync` at the output path. ## Architecture diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index d93838e59..fea4f04f3 100644 --- a/.cross_sync/generate.py +++ b/.cross_sync/generate.py @@ -14,56 +14,30 @@ from __future__ import annotations from typing import Sequence import ast -from dataclasses import dataclass, field """ Entrypoint for initiating an async -> sync conversion using CrossSync Finds all python files rooted in a given directory, and uses -transformers.CrossSyncClassDecoratorHandler to handle any CrossSync class -decorators found in the files. +transformers.CrossSyncFileHandler to handle any files marked with +__CROSS_SYNC_OUTPUT__ """ -@dataclass class CrossSyncOutputFile: - """ - Represents an output file location. - Multiple decorated async classes may point to the same output location for - their generated sync code. This class holds all the information needed to - write the output file to disk. - """ + def __init__(self, file_path: str, ast_tree): + self.file_path = file_path + self.tree = ast_tree - # The path to the output file - file_path: str - # The import headers to write to the top of the output file - # will be populated when CrossSync.export_sync(include_file_imports=True) - imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = field( - default_factory=list - ) - # The set of sync ast.ClassDef nodes to write to the output file - converted_classes: list[ast.ClassDef] = field(default_factory=list) - # the set of classes contained in the file. Used to prevent duplicates - contained_classes: set[str] = field(default_factory=set) - # the set of mypy error codes to ignore at the file level - # configured using CrossSync.export_sync(mypy_ignore=["error_code"]) - mypy_ignore: list[str] = field(default_factory=list) - - def __hash__(self): - return hash(self.file_path) - - def __repr__(self): - return f"CrossSyncOutputFile({self.file_path}, classes={[c.name for c in self.converted_classes]})" - - def render(self, with_black=True, save_to_disk=False) -> str: + def render(self, with_black=True, save_to_disk: bool = False) -> str: """ - Render the output file as a string. 
+ Render the file to a string, and optionally save to disk Args: with_black: whether to run the output through black before returning save_to_disk: whether to write the output to the file path """ - full_str = ( + header = ( "# Copyright 2024 Google LLC\n" "#\n" '# Licensed under the Apache License, Version 2.0 (the "License");\n' @@ -80,13 +54,7 @@ def render(self, with_black=True, save_to_disk=False) -> str: "#\n" "# This file is automatically generated by CrossSync. Do not edit manually.\n" ) - if self.mypy_ignore: - full_str += ( - f'\n# mypy: disable-error-code="{",".join(self.mypy_ignore)}"\n\n' - ) - full_str += "\n".join([ast.unparse(node) for node in self.imports]) # type: ignore - full_str += "\n\n" - full_str += "\n".join([ast.unparse(node) for node in self.converted_classes]) # type: ignore + full_str = header + ast.unparse(self.converted) if with_black: import black # type: ignore import autoflake # type: ignore @@ -96,30 +64,33 @@ def render(self, with_black=True, save_to_disk=False) -> str: mode=black.FileMode(), ) if save_to_disk: - # create parent paths if needed import os - os.makedirs(os.path.dirname(self.file_path), exist_ok=True) - with open(self.file_path, "w") as f: + os.makedirs(os.path.dirname(self.output_path), exist_ok=True) + with open(self.output_path, "w") as f: f.write(full_str) return full_str def convert_files_in_dir(directory: str) -> set[CrossSyncOutputFile]: import glob - from transformers import CrossSyncClassDecoratorHandler + from transformers import CrossSyncFileHandler # find all python files in the directory files = glob.glob(directory + "/**/*.py", recursive=True) # keep track of the output files pointed to by the annotated classes artifacts: set[CrossSyncOutputFile] = set() + file_transformer = CrossSyncFileHandler() # run each file through ast transformation to find all annotated classes - for file in files: - converter = CrossSyncClassDecoratorHandler(file) - new_outputs = converter.convert_file(artifacts) - artifacts.update(new_outputs) + for file_path in files: + file = open(file_path).read() + converted_tree = file_transformer.visit(ast.parse(file)) + if converted_tree is not None: + # contains __CROSS_SYNC_OUTPUT__ annotation + artifacts.add(CrossSyncOutputFile(file_path, converted_tree)) # return set of output artifacts return artifacts + def save_artifacts(artifacts: Sequence[CrossSyncOutputFile]): for a in artifacts: a.render(save_to_disk=True) @@ -130,5 +101,5 @@ def save_artifacts(artifacts: Sequence[CrossSyncOutputFile]): search_root = sys.argv[1] outputs = convert_files_in_dir(search_root) - print(f"Generated {len(outputs)} artifacts: {[a.file_path for a in outputs]}") + print(f"Generated {len(outputs)} artifacts: {[a.file_name for a in outputs]}") save_artifacts(outputs) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index b6a34c690..5afef0d41 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -32,7 +32,6 @@ # add cross_sync to path sys.path.append("google/cloud/bigtable/data/_sync/cross_sync") from _decorators import AstDecorator, ExportSync -from generate import CrossSyncOutputFile class SymbolReplacer(ast.NodeTransformer): @@ -227,96 +226,54 @@ def visit_AsyncFunctionDef(self, node): raise ValueError(f"node {node.name} failed") from e -class CrossSyncClassDecoratorHandler(ast.NodeTransformer): +class CrossSyncFileHandler(ast.NodeTransformer): """ - Visits each class in the file, and if it has a CrossSync decorator, it will be transformed. 
- - Uses CrossSyncMethodDecoratorHandler to visit and (potentially) convert each method in the class + Visit each file, and process CrossSync classes if found """ - def __init__(self, file_path): - self.in_path = file_path - self._artifact_dict: dict[str, CrossSyncOutputFile] = {} - self.imports: list[ast.Import | ast.ImportFrom | ast.Try | ast.If] = [] - self.cross_sync_symbol_transformer = SymbolReplacer( - {"CrossSync": "CrossSync._Sync_Impl"} - ) - self.cross_sync_method_handler = CrossSyncMethodDecoratorHandler() - - def convert_file( - self, artifacts: set[CrossSyncOutputFile] | None = None - ) -> set[CrossSyncOutputFile]: - """ - Called to run a file through the ast transformer. - If the file contains any classes marked with CrossSync.export_sync, the - classes will be processed according to the decorator arguments, and - a set of CrossSyncOutputFile objects will be returned for each output file. - - If no CrossSync annotations are found, no changes will occur and an - empty set will be returned - """ - tree = ast.parse(open(self.in_path).read()) - self._artifact_dict = {f.file_path: f for f in artifacts or []} - self.imports = self._get_imports(tree) - self.visit(tree) - # return set of new artifacts - return set(self._artifact_dict.values()).difference(artifacts or []) + @staticmethod + def _find_cs_output(node): + for i, n in enumerate(node.body): + if isinstance(n, ast.Assign): + for target in n.targets: + if isinstance(target, ast.Name) and target.id == "__CROSS_SYNC_OUTPUT__": + # keep the output path + # remove the statement + node.body.pop(i) + return n.value.value + ".py" + + def visit_Module(self, node): + # look for __CROSS_SYNC_OUTPUT__ Assign statement + self.output_path = self._find_cs_output(node) + if self.output_path: + # if found, process the file + return self.generic_visit(node) + else: + # not cross_sync file. Return None + return None def visit_ClassDef(self, node): """ Called for each class in file. If class has a CrossSync decorator, it will be transformed - according to the decorator arguments. Otherwise, no changes will occur - - Uses a set of CrossSyncOutputFile objects to store the transformed classes - and avoid duplicate writes + according to the decorator arguments. 
Otherwise, class is returned unchanged """ - try: - converted = None - for decorator in node.decorator_list: - try: - handler = AstDecorator.get_for_node(decorator) - if isinstance(handler, ExportSync): - # find the path to write the sync class to - out_file = "/".join(handler.path.rsplit(".")[:-1]) + ".py" - sync_cls_name = handler.path.rsplit(".", 1)[-1] - # find the artifact file for the save location - output_artifact = self._artifact_dict.get( - out_file, CrossSyncOutputFile(out_file) - ) - # write converted class details if not already present - if sync_cls_name not in output_artifact.contained_classes: - # transformation is handled in sync_ast_transform method of the decorator - converted = handler.sync_ast_transform(node, globals()) - output_artifact.converted_classes.append(converted) - # handle file-level mypy ignores - mypy_ignores = [ - s - for s in handler.mypy_ignore - if s not in output_artifact.mypy_ignore - ] - output_artifact.mypy_ignore.extend(mypy_ignores) - # handle file-level imports - if not output_artifact.imports and handler.include_file_imports: - output_artifact.imports = self.imports - self._artifact_dict[out_file] = output_artifact - except ValueError: - continue - return converted - except ValueError as e: - raise ValueError(f"failed for class: {node.name}") from e + for decorator in node.decorator_list: + try: + handler = AstDecorator.get_for_node(decorator) + if isinstance(handler, ExportSync): + # transformation is handled in sync_ast_transform method of the decorator + return handler.sync_ast_transform(node, globals()) + except ValueError: + # not cross_sync decorator + continue + # cross_sync decorator not found + return node - def _get_imports( - self, tree: ast.Module - ) -> list[ast.Import | ast.ImportFrom | ast.Try | ast.If]: + def visit_If(self, node): """ - Grab the imports from the top of the file - - raw imports, as well as try and if statements at the top level are included + remove CrossSync.is_async branches from top-level if statements """ - imports = [] - for node in tree.body: - if isinstance(node, (ast.Import, ast.ImportFrom, ast.Try, ast.If)): - imports.append(self.cross_sync_symbol_transformer.visit(node)) - return imports - + if isinstance(node.test, ast.Attribute) and isinstance(node.test.value, ast.Name) and node.test.value.id == "CrossSync" and node.test.attr == "is_async": + return node.orelse + return self.generic_visit(node) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index b1456951b..35219bee1 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: import ast - from typing import Sequence, Callable, Any + from typing import Callable, Any class AstDecorator: @@ -175,25 +175,21 @@ class ExportSync(AstDecorator): Class decorator for marking async classes to be converted to sync classes Args: - path: path to output the generated sync class + sync_name: use a new name for the sync class replace_symbols: a dict of symbols and replacements to use when generating sync class docstring_format_vars: a dict of variables to replace in the docstring - mypy_ignore: set of mypy errors to ignore in the generated file - include_file_imports: if True, include top-level imports from the file in the generated sync class add_mapping_for_name: when given, will add a new attribute to CrossSync, so the original class and its sync version can be accessed 
from CrossSync. """ def __init__( self, - path: str, + sync_name: str | None = None, *, replace_symbols: dict[str, str] | None = None, docstring_format_vars: dict[str, tuple[str, str]] | None = None, - mypy_ignore: Sequence[str] = (), - include_file_imports: bool = True, add_mapping_for_name: str | None = None, ): - self.path = path + self.sync_name = sync_name self.replace_symbols = replace_symbols docstring_format_vars = docstring_format_vars or {} self.async_docstring_format_vars = { @@ -202,8 +198,6 @@ def __init__( self.sync_docstring_format_vars = { k: v[1] for k, v in docstring_format_vars.items() } - self.mypy_ignore = mypy_ignore - self.include_file_imports = include_file_imports self.add_mapping_for_name = add_mapping_for_name def async_decorator(self): @@ -230,15 +224,11 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): import ast import copy - if not self.path: - raise ValueError( - f"{wrapped_node.name} has no path specified in export_sync decorator" - ) # copy wrapped node wrapped_node = copy.deepcopy(wrapped_node) # update name - sync_cls_name = self.path.rsplit(".", 1)[-1] - wrapped_node.name = sync_cls_name + if self.sync_name: + wrapped_node.name = self.sync_name # strip CrossSync decorators if hasattr(wrapped_node, "decorator_list"): wrapped_node.decorator_list = [ diff --git a/tests/system/cross_sync/test_cases/cross_sync_classes.yaml b/tests/system/cross_sync/test_cases/cross_sync_files.yaml similarity index 70% rename from tests/system/cross_sync/test_cases/cross_sync_classes.yaml rename to tests/system/cross_sync/test_cases/cross_sync_files.yaml index f38335e87..a49b8189e 100644 --- a/tests/system/cross_sync/test_cases/cross_sync_classes.yaml +++ b/tests/system/cross_sync/test_cases/cross_sync_files.yaml @@ -1,15 +1,40 @@ tests: - - description: "No conversion needed" + - description: "No output annotation" before: | - @CrossSync.export_sync(path="example.sync.MyClass") class MyAsyncClass: async def my_method(self): pass transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler + after: null + + - description: "CrossSync.export_sync with default sync_name" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.export_sync + class MyClass: + async def my_method(self): + pass + + transformers: + - name: CrossSyncFileHandler + after: | + class MyClass: + + async def my_method(self): + pass + + - description: "CrossSync.export_sync with custom sync_name" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.export_sync(sync_name="MyClass") + class MyAsyncClass: + async def my_method(self): + pass + + transformers: + - name: CrossSyncFileHandler after: | class MyClass: @@ -18,8 +43,9 @@ tests: - description: "CrossSync.export_sync with replace_symbols" before: | + __CROSS_SYNC_OUTPUT__ = "out.path" @CrossSync.export_sync( - path="example.sync.MyClass", + sync_name="MyClass", replace_symbols={"AsyncBase": "SyncBase", "ParentA": "ParentB"} ) class MyAsyncClass(ParentA): @@ -27,9 +53,7 @@ tests: self.base = base transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler after: | class MyClass(ParentB): @@ -38,24 +62,24 @@ tests: - description: "CrossSync.export_sync with docstring formatting" before: | + __CROSS_SYNC_OUTPUT__ = "out.path" @CrossSync.export_sync( - path="example.sync.MyClass", + sync_name="MyClass", docstring_format_vars={"type": ("async", "sync")} ) class MyAsyncClass: """This is a {type} class.""" 
transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler after: | class MyClass: """This is a sync class.""" - description: "CrossSync.export_sync with multiple decorators and methods" before: | - @CrossSync.export_sync(path="example.sync.MyClass") + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.export_sync(sync_name="MyClass") @some_other_decorator class MyAsyncClass: @CrossSync.convert @@ -75,9 +99,7 @@ tests: pass transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler after: | @some_other_decorator class MyClass: @@ -95,7 +117,8 @@ tests: - description: "CrossSync.export_sync with nested classes" before: | - @CrossSync.export_sync(path="example.sync.MyClass", replace_symbols={"AsyncBase": "SyncBase"}) + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.export_sync(sync_name="MyClass", replace_symbols={"AsyncBase": "SyncBase"}) class MyAsyncClass: class NestedAsyncClass: async def nested_method(self, base: AsyncBase): @@ -110,9 +133,7 @@ tests: nested = self.NestedAsyncClass() CrossSync.rm_aio(await nested.nested_method()) transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler after: | class MyClass: @@ -127,8 +148,9 @@ tests: - description: "CrossSync.export_sync with add_mapping" before: | + __CROSS_SYNC_OUTPUT__ = "out.path" @CrossSync.export_sync( - path="example.sync.MyClass", + sync_name="MyClass", add_mapping_for_name="MyClass" ) class MyAsyncClass: @@ -136,9 +158,7 @@ tests: pass transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler after: | @CrossSync._Sync_Impl.add_mapping_decorator("MyClass") class MyClass: @@ -148,7 +168,8 @@ tests: - description: "CrossSync.export_sync with CrossSync calls" before: | - @CrossSync.export_sync(path="example.sync.MyClass") + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.export_sync(sync_name="MyClass") class MyAsyncClass: @CrossSync.convert async def my_method(self): @@ -156,9 +177,7 @@ tests: CrossSync.rm_aio(await CrossSync.yield_to_event_loop()) transformers: - - name: CrossSyncClassDecoratorHandler - args: - file_path: "dummy_path.py" + - name: CrossSyncFileHandler after: | class MyClass: diff --git a/tests/system/cross_sync/test_cross_sync_e2e.py b/tests/system/cross_sync/test_cross_sync_e2e.py index 489e042fe..bd08ed6cb 100644 --- a/tests/system/cross_sync/test_cross_sync_e2e.py +++ b/tests/system/cross_sync/test_cross_sync_e2e.py @@ -15,7 +15,7 @@ AsyncToSync, RmAioFunctions, CrossSyncMethodDecoratorHandler, - CrossSyncClassDecoratorHandler, + CrossSyncFileHandler, ) @@ -42,7 +42,7 @@ def loader(): sys.version_info < (3, 9), reason="ast.unparse requires python3.9 or higher" ) def test_e2e_scenario(test_dict): - before_ast = ast.parse(test_dict["before"]).body[0] + before_ast = ast.parse(test_dict["before"]) got_ast = before_ast for transformer_info in test_dict["transformers"]: # transformer can be passed as a string, or a dict with name and args @@ -54,6 +54,12 @@ def test_e2e_scenario(test_dict): transformer_args = transformer_info.get("args", {}) transformer = transformer_class(**transformer_args) got_ast = transformer.visit(got_ast) - final_str = black.format_str(ast.unparse(got_ast), mode=black.FileMode()) - expected_str = black.format_str(test_dict["after"], mode=black.FileMode()) + if got_ast is None: + final_str = "" + else: + final_str = 
black.format_str(ast.unparse(got_ast), mode=black.FileMode()) + if test_dict.get("after") is None: + expected_str = "" + else: + expected_str = black.format_str(test_dict["after"], mode=black.FileMode()) assert final_str == expected_str, f"Expected:\n{expected_str}\nGot:\n{final_str}" diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_sync/test_cross_sync_decorators.py index 988d8d113..6c817fd9d 100644 --- a/tests/unit/data/_sync/test_cross_sync_decorators.py +++ b/tests/unit/data/_sync/test_cross_sync_decorators.py @@ -45,39 +45,27 @@ def test_ctor_defaults(self): """ Should set default values for path, add_mapping_for_name, and docstring_format_vars """ - with pytest.raises(TypeError) as exc: - self._get_class()() - assert "missing 1 required positional argument" in str(exc.value) - path = object() - instance = self._get_class()(path) - assert instance.path is path + instance = self._get_class()() + assert instance.sync_name is None assert instance.replace_symbols is None - assert instance.mypy_ignore == () - assert instance.include_file_imports is True assert instance.add_mapping_for_name is None assert instance.async_docstring_format_vars == {} assert instance.sync_docstring_format_vars == {} def test_ctor(self): - path = object() + sync_name = "sync_name" replace_symbols = {"a": "b"} docstring_format_vars = {"A": (1, 2)} - mypy_ignore = ("a", "b") - include_file_imports = False add_mapping_for_name = "test_name" instance = self._get_class()( - path=path, + sync_name, replace_symbols=replace_symbols, docstring_format_vars=docstring_format_vars, - mypy_ignore=mypy_ignore, - include_file_imports=include_file_imports, add_mapping_for_name=add_mapping_for_name, ) - assert instance.path is path + assert instance.sync_name is sync_name assert instance.replace_symbols is replace_symbols - assert instance.mypy_ignore is mypy_ignore - assert instance.include_file_imports is include_file_imports assert instance.add_mapping_for_name is add_mapping_for_name assert instance.async_docstring_format_vars == {"A": 1} assert instance.sync_docstring_format_vars == {"A": 2} @@ -87,7 +75,7 @@ def test_class_decorator(self): Should return class being decorated """ unwrapped_class = mock.Mock - wrapped_class = self._get_class().decorator(unwrapped_class, path=1) + wrapped_class = self._get_class().decorator(unwrapped_class, sync_name="s") assert unwrapped_class == wrapped_class def test_class_decorator_adds_mapping(self): @@ -97,11 +85,13 @@ def test_class_decorator_adds_mapping(self): with mock.patch.object(CrossSync, "add_mapping") as add_mapping: mock_cls = mock.Mock # check decoration with no add_mapping - self._get_class().decorator(path=1)(mock_cls) + self._get_class().decorator(sync_name="s")(mock_cls) assert add_mapping.call_count == 0 # check decoration with add_mapping name = "test_name" - self._get_class().decorator(path=1, add_mapping_for_name=name)(mock_cls) + self._get_class().decorator(sync_name="s", add_mapping_for_name=name)( + mock_cls + ) assert add_mapping.call_count == 1 add_mapping.assert_called_once_with(name, mock_cls) @@ -122,13 +112,13 @@ def test_class_decorator_docstring_update(self, docstring, format_vars, expected of the class being decorated """ - @ExportSync.decorator(path=1, docstring_format_vars=format_vars) + @ExportSync.decorator(sync_name="s", docstring_format_vars=format_vars) class Class: __doc__ = docstring assert Class.__doc__ == expected # check internal state - instance = self._get_class()(path=1, docstring_format_vars=format_vars) + 
instance = self._get_class()(sync_name="s", docstring_format_vars=format_vars) async_replacements = {k: v[0] for k, v in format_vars.items()} sync_replacements = {k: v[1] for k, v in format_vars.items()} assert instance.async_docstring_format_vars == async_replacements @@ -138,7 +128,7 @@ def test_sync_ast_transform_replaces_name(self, globals_mock): """ Should update the name of the new class """ - decorator = self._get_class()("path.to.SyncClass") + decorator = self._get_class()("SyncClass") mock_node = ast.ClassDef(name="AsyncClass", bases=[], keywords=[], body=[]) result = decorator.sync_ast_transform(mock_node, globals_mock) From 18484a21e104bf1fea66e42eca3831ddc37c59c4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 10 Sep 2024 17:22:05 -0700 Subject: [PATCH 268/360] updated CrossSync.export statements --- .../bigtable/data/_async/_mutate_rows.py | 6 ++--- .../cloud/bigtable/data/_async/_read_rows.py | 6 ++--- google/cloud/bigtable/data/_async/client.py | 8 +++--- .../bigtable/data/_async/mutations_batcher.py | 11 +++----- .../_async/execute_query_iterator.py | 6 ++--- tests/system/data/test_system_async.py | 6 +++-- tests/unit/data/_async/test__mutate_rows.py | 5 ++-- tests/unit/data/_async/test__read_rows.py | 4 ++- tests/unit/data/_async/test_client.py | 26 +++++++++---------- .../data/_async/test_mutations_batcher.py | 6 +++-- .../data/_async/test_read_rows_acceptance.py | 4 ++- .../_async/test_query_iterator.py | 6 +++-- 12 files changed, 49 insertions(+), 45 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 69279820e..3d79361bc 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -44,10 +44,10 @@ ) from google.cloud.bigtable.data._sync.client import Table as TableType # type: ignore +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync._mutate_rows" -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync._mutate_rows._MutateRowsOperation", -) + +@CrossSync.export_sync("MutateRowsOperation") class _MutateRowsOperationAsync: """ MutateRowsOperation manages the logic of sending a set of row mutations, diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index aba69a2a7..a2d8195a6 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -42,10 +42,10 @@ else: from google.cloud.bigtable.data._sync.client import Table as TableType # type: ignore +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync._read_rows" -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync._read_rows._ReadRowsOperation", -) + +@CrossSync.export_sync("ReadRowsOperation") class _ReadRowsOperationAsync: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 827db9839..b1bc3fcfa 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -106,9 +106,11 @@ from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync.client" + @CrossSync.export_sync( - path="google.cloud.bigtable.data._sync.client.BigtableDataClient", + sync_name="BigtableDataClient", add_mapping_for_name="DataClient", ) class 
BigtableDataClientAsync(ClientWithProject): @@ -619,9 +621,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): CrossSync.rm_aio(await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb)) -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync.client.Table", add_mapping_for_name="Table" -) +@CrossSync.export_sync(sync_name="Table", add_mapping_for_name="Table") class TableAsync: """ Main Data API surface diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 443730d08..2bd3f7b35 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -42,11 +42,10 @@ else: from google.cloud.bigtable.data._sync.client import Table as TableType # type: ignore +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync.mutations_batcher" -@CrossSync.export_sync( - path="google.cloud.bigtable.data._sync.mutations_batcher._FlowControl", - add_mapping_for_name="_FlowControl", -) + +@CrossSync.export_sync(sync_name="_FlowControl", add_mapping_for_name="_FlowControl") class _FlowControlAsync: """ Manages flow control for batched mutations. Mutations are registered against @@ -178,9 +177,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] @CrossSync.export_sync( - path="google.cloud.bigtable.data._sync.mutations_batcher.MutationsBatcher", - mypy_ignore=["unreachable"], - add_mapping_for_name="MutationsBatcher", + sync_name="MutationsBatcher", add_mapping_for_name="MutationsBatcher" ) class MutationsBatcherAsync: """ diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index 43f68a926..46389bb06 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -52,10 +52,10 @@ else: from google.cloud.bigtable.data import BigtableDataClient as DataClientType +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data.execute_query._sync.execute_query_iterator" -@CrossSync.export_sync( - path="google.cloud.bigtable.data.execute_query._sync.execute_query_iterator.ExecuteQueryIterator", -) + +@CrossSync.export_sync(sync_name="ExecuteQueryIterator") class ExecuteQueryIteratorAsync: @CrossSync.convert( diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index e8edeef0f..3493bba24 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -27,8 +27,10 @@ from . 
import TEST_FAMILY, TEST_FAMILY_2 +__CROSS_SYNC_OUTPUT__ = "tests.system.data.test_system" + @CrossSync.export_sync( - path="tests.system.data.test_system.TempRowBuilder", + sync_name="TempRowBuilder", add_mapping_for_name="TempRowBuilder", ) class TempRowBuilderAsync: @@ -77,7 +79,7 @@ async def delete_rows(self): CrossSync.rm_aio(await self.table.client._gapic_client.mutate_rows(request)) -@CrossSync.export_sync(path="tests.system.data.test_system.TestSystem") +@CrossSync.export_sync(sync_name="TestSystem") class TestSystemAsync: @CrossSync.convert @CrossSync.pytest_fixture(scope="session") diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index a307a7008..62f7f88ba 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -27,10 +27,9 @@ except ImportError: # pragma: NO COVER import mock # type: ignore +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test__mutate_rows" -@CrossSync.export_sync( - path="tests.unit.data._sync.test__mutate_rows.TestMutateRowsOperation", -) +@CrossSync.export_sync("TestMutateRowsOperation") class TestMutateRowsOperation: def _target_class(self): return CrossSync._MutateRowsOperation diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 896c17879..ccecab7fd 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -22,8 +22,10 @@ import mock # type: ignore +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test__read_rows" + @CrossSync.export_sync( - path="tests.unit.data._sync.test__read_rows.TestReadRowsOperation", + sync_name="TestReadRowsOperation", ) class TestReadRowsOperationAsync: """ diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index bd5975d6e..431a689b9 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -51,9 +51,11 @@ CrossSync.add_mapping("grpc_helpers", grpc_helpers) +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_client" + @CrossSync.export_sync( - path="tests.unit.data._sync.test_client.TestBigtableDataClient", + sync_name="TestBigtableDataClient", add_mapping_for_name="TestBigtableDataClient", ) class TestBigtableDataClientAsync: @@ -1114,7 +1116,7 @@ def test_client_ctor_sync(self): @CrossSync.export_sync( - path="tests.unit.data._sync.test_client.TestTable", add_mapping_for_name="TestTable" + "TestTable", add_mapping_for_name="TestTable" ) class TestTableAsync: @CrossSync.convert @@ -1428,7 +1430,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ @CrossSync.export_sync( - path="tests.unit.data._sync.test_client.TestReadRows", + "TestReadRows", add_mapping_for_name="TestReadRows", ) class TestReadRowsAsync: @@ -1940,7 +1942,7 @@ async def test_row_exists(self, return_value, expected_result): assert query.filter._to_dict() == expected_filter -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadRowsSharded") +@CrossSync.export_sync("TestReadRowsSharded") class TestReadRowsShardedAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -2160,7 +2162,7 @@ async def mock_call(*args, **kwargs): ) -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestSampleRowKeys") +@CrossSync.export_sync("TestSampleRowKeys") class TestSampleRowKeysAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -2314,9 +2316,7 @@ async def test_sample_row_keys_non_retryable_errors(self, 
non_retryable_exceptio await table.sample_row_keys() -@CrossSync.export_sync( - path="tests.unit.data._sync.test_client.TestMutateRow", -) +@CrossSync.export_sync("TestMutateRow") class TestMutateRowAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -2493,9 +2493,7 @@ async def test_mutate_row_no_mutations(self, mutations): assert e.value.args[0] == "No mutations provided" -@CrossSync.export_sync( - path="tests.unit.data._sync.test_client.TestBulkMutateRows", -) +@CrossSync.export_sync("TestBulkMutateRows") class TestBulkMutateRowsAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -2878,7 +2876,7 @@ async def test_bulk_mutate_error_recovery(self): await table.bulk_mutate_rows(entries, operation_timeout=1000) -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestCheckAndMutateRow") +@CrossSync.export_sync("TestCheckAndMutateRow") class TestCheckAndMutateRowAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -3031,7 +3029,7 @@ async def test_check_and_mutate_mutations_parsing(self): ) -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestReadModifyWriteRow") +@CrossSync.export_sync("TestReadModifyWriteRow") class TestReadModifyWriteRowAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -3163,7 +3161,7 @@ async def test_read_modify_write_row_building(self): constructor_mock.assert_called_once_with(mock_response.row) -@CrossSync.export_sync(path="tests.unit.data._sync.test_client.TestExecuteQuery") +@CrossSync.export_sync("TestExecuteQuery") class TestExecuteQueryAsync: TABLE_NAME = "TABLE_NAME" INSTANCE_NAME = "INSTANCE_NAME" diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 6e5949575..5cddb08f9 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -29,8 +29,10 @@ import mock # type: ignore +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_mutations_batcher" + @CrossSync.export_sync( - path="tests.unit.data._sync.test_mutations_batcher.Test_FlowControl" + sync_name="Test_FlowControl" ) class Test_FlowControl: @staticmethod @@ -302,7 +304,7 @@ async def test_add_to_flow_oversize(self): @CrossSync.export_sync( - path="tests.unit.data._sync.test_mutations_batcher.TestMutationsBatcher" + sync_name="TestMutationsBatcher" ) class TestMutationsBatcherAsync: @CrossSync.convert diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 0bd5d82f8..9032b1a34 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -30,8 +30,10 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_read_rows_acceptance" + @CrossSync.export_sync( - path="tests.unit.data._sync.test_read_rows_acceptance.TestReadRowsAcceptance", + sync_name="TestReadRowsAcceptance", ) class TestReadRowsAcceptanceAsync: @staticmethod diff --git a/tests/unit/data/execute_query/_async/test_query_iterator.py b/tests/unit/data/execute_query/_async/test_query_iterator.py index f55b2f9b5..2990eac1a 100644 --- a/tests/unit/data/execute_query/_async/test_query_iterator.py +++ b/tests/unit/data/execute_query/_async/test_query_iterator.py @@ -30,8 +30,10 @@ import mock # type: ignore +__CROSS_SYNC_OUTPUT__ = "tests.unit.data.execute_query._sync.test_query_iterator" + @CrossSync.export_sync( - 
path="tests.unit.data.execute_query._sync.test_query_iterator.MockIterator" + sync_name="MockIterator" ) class MockIterator: def __init__(self, values, delay=None): @@ -55,7 +57,7 @@ async def __anext__(self): @CrossSync.export_sync( - path="tests.unit.data.execute_query._sync.test_query_iterator.TestQueryIterator" + sync_name="TestQueryIterator" ) class TestQueryIteratorAsync: @staticmethod From 41576f756d1a5eb9d6d242c8d0e2810c0a06e368 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 10 Sep 2024 18:06:42 -0700 Subject: [PATCH 269/360] fixed issues in generation --- .cross_sync/generate.py | 17 +++++++++-------- .cross_sync/transformers.py | 10 ++++------ 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index fea4f04f3..5c130079f 100644 --- a/.cross_sync/generate.py +++ b/.cross_sync/generate.py @@ -54,7 +54,7 @@ def render(self, with_black=True, save_to_disk: bool = False) -> str: "#\n" "# This file is automatically generated by CrossSync. Do not edit manually.\n" ) - full_str = header + ast.unparse(self.converted) + full_str = header + ast.unparse(self.tree) if with_black: import black # type: ignore import autoflake # type: ignore @@ -65,8 +65,8 @@ def render(self, with_black=True, save_to_disk: bool = False) -> str: ) if save_to_disk: import os - os.makedirs(os.path.dirname(self.output_path), exist_ok=True) - with open(self.output_path, "w") as f: + os.makedirs(os.path.dirname(self.file_path), exist_ok=True) + with open(self.file_path, "w") as f: f.write(full_str) return full_str @@ -82,11 +82,12 @@ def convert_files_in_dir(directory: str) -> set[CrossSyncOutputFile]: file_transformer = CrossSyncFileHandler() # run each file through ast transformation to find all annotated classes for file_path in files: - file = open(file_path).read() - converted_tree = file_transformer.visit(ast.parse(file)) - if converted_tree is not None: + ast_tree = ast.parse(open(file_path).read()) + output_path = file_transformer.get_output_path(ast_tree) + if output_path is not None: # contains __CROSS_SYNC_OUTPUT__ annotation - artifacts.add(CrossSyncOutputFile(file_path, converted_tree)) + converted_tree = file_transformer.visit(ast_tree) + artifacts.add(CrossSyncOutputFile(output_path, converted_tree)) # return set of output artifacts return artifacts @@ -101,5 +102,5 @@ def save_artifacts(artifacts: Sequence[CrossSyncOutputFile]): search_root = sys.argv[1] outputs = convert_files_in_dir(search_root) - print(f"Generated {len(outputs)} artifacts: {[a.file_name for a in outputs]}") + print(f"Generated {len(outputs)} artifacts: {[a.file_path for a in outputs]}") save_artifacts(outputs) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 5afef0d41..1744e9da0 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -232,20 +232,18 @@ class CrossSyncFileHandler(ast.NodeTransformer): """ @staticmethod - def _find_cs_output(node): + def get_output_path(node): for i, n in enumerate(node.body): if isinstance(n, ast.Assign): for target in n.targets: if isinstance(target, ast.Name) and target.id == "__CROSS_SYNC_OUTPUT__": # keep the output path - # remove the statement - node.body.pop(i) - return n.value.value + ".py" + return n.value.value.replace(".", "/") + ".py" def visit_Module(self, node): # look for __CROSS_SYNC_OUTPUT__ Assign statement - self.output_path = self._find_cs_output(node) - if self.output_path: + output_path = self.get_output_path(node) + if output_path: # if found, process the file 
return self.generic_visit(node) else: From 0b9f377a4cfd43a6186374ae020d5927a39cff26 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 10 Sep 2024 18:11:57 -0700 Subject: [PATCH 270/360] fixed lint --- google/cloud/bigtable/data/_async/client.py | 39 +++++++++++++++---- .../_async/execute_query_iterator.py | 12 +++--- tests/system/data/test_system_async.py | 6 ++- tests/unit/data/_async/test__mutate_rows.py | 1 + tests/unit/data/_async/test__read_rows.py | 1 + tests/unit/data/_async/test_client.py | 5 +-- .../data/_async/test_mutations_batcher.py | 10 ++--- .../data/_async/test_read_rows_acceptance.py | 1 + .../_async/test_query_iterator.py | 12 ++---- 9 files changed, 55 insertions(+), 32 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index b1bc3fcfa..85f24ef7e 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -116,8 +116,14 @@ class BigtableDataClientAsync(ClientWithProject): @CrossSync.convert( docstring_format_vars={ - "LOOP_MESSAGE": ("Client should be created within an async context (running event loop)", ""), - "RAISE_NO_LOOP": ("RuntimeError: if called outside of an async context (no running event loop)", ""), + "LOOP_MESSAGE": ( + "Client should be created within an async context (running event loop)", + "", + ), + "RAISE_NO_LOOP": ( + "RuntimeError: if called outside of an async context (no running event loop)", + "", + ), } ) def __init__( @@ -236,7 +242,12 @@ def _client_version() -> str: return version_str @CrossSync.convert( - docstring_format_vars={"RAISE_NO_LOOP": ("RuntimeError: if not called in an asyncio event loop", "None")} + docstring_format_vars={ + "RAISE_NO_LOOP": ( + "RuntimeError: if not called in an asyncio event loop", + "None", + ) + } ) def _start_background_channel_refresh(self) -> None: """ @@ -464,8 +475,14 @@ async def _remove_instance_registration( @CrossSync.convert( replace_symbols={"TableAsync": "Table"}, docstring_format_vars={ - "LOOP_MESSAGE": ("Must be created within an async context (running event loop)", ""), - "RAISE_NO_LOOP": ("RuntimeError: if called outside of an async context (no running event loop)", "None"), + "LOOP_MESSAGE": ( + "Must be created within an async context (running event loop)", + "", + ), + "RAISE_NO_LOOP": ( + "RuntimeError: if called outside of an async context (no running event loop)", + "None", + ), }, ) def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: @@ -633,9 +650,15 @@ class TableAsync: @CrossSync.convert( replace_symbols={"BigtableDataClientAsync": "BigtableDataClient"}, docstring_format_vars={ - "LOOP_MESSAGE": ("Must be created within an async context (running event loop)", ""), - "RAISE_NO_LOOP": ("RuntimeError: if called outside of an async context (no running event loop)", "None"), - } + "LOOP_MESSAGE": ( + "Must be created within an async context (running event loop)", + "", + ), + "RAISE_NO_LOOP": ( + "RuntimeError: if called outside of an async context (no running event loop)", + "None", + ), + }, ) def __init__( self, diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index 46389bb06..1e780ea75 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -14,10 +14,8 @@ from __future__ import annotations -import asyncio 
from typing import ( Any, - AsyncIterator, Dict, List, Optional, @@ -52,15 +50,19 @@ else: from google.cloud.bigtable.data import BigtableDataClient as DataClientType -__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data.execute_query._sync.execute_query_iterator" +__CROSS_SYNC_OUTPUT__ = ( + "google.cloud.bigtable.data.execute_query._sync.execute_query_iterator" +) @CrossSync.export_sync(sync_name="ExecuteQueryIterator") class ExecuteQueryIteratorAsync: - @CrossSync.convert( docstring_format_vars={ - "NO_LOOP": ("RuntimeError: if the instance is not created within an async event loop context.", "None"), + "NO_LOOP": ( + "RuntimeError: if the instance is not created within an async event loop context.", + "None", + ), "TASK_OR_THREAD": ("asyncio Tasks", "threads"), } ) diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 3493bba24..6d068c152 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -29,6 +29,7 @@ __CROSS_SYNC_OUTPUT__ = "tests.system.data.test_system" + @CrossSync.export_sync( sync_name="TempRowBuilder", add_mapping_for_name="TempRowBuilder", @@ -495,7 +496,10 @@ async def test_mutations_batcher_large_batch(self, client, table, temp_rows): test batcher with large batch of mutations """ from google.cloud.bigtable.data.mutations import RowMutationEntry, SetCell - add_mutation = SetCell(family=TEST_FAMILY, qualifier=b"test-qualifier", new_value=b"a") + + add_mutation = SetCell( + family=TEST_FAMILY, qualifier=b"test-qualifier", new_value=b"a" + ) row_mutations = [] for i in range(50_000): row_key = uuid.uuid4().hex.encode() diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 62f7f88ba..35184652c 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -29,6 +29,7 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test__mutate_rows" + @CrossSync.export_sync("TestMutateRowsOperation") class TestMutateRowsOperation: def _target_class(self): diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index ccecab7fd..c41914e84 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -24,6 +24,7 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test__read_rows" + @CrossSync.export_sync( sync_name="TestReadRowsOperation", ) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 431a689b9..497326175 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -31,6 +31,7 @@ from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule +from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse from google.cloud.bigtable.data._sync.cross_sync import CrossSync @@ -1115,9 +1116,7 @@ def test_client_ctor_sync(self): assert client._channel_refresh_tasks == [] -@CrossSync.export_sync( - "TestTable", add_mapping_for_name="TestTable" -) +@CrossSync.export_sync("TestTable", add_mapping_for_name="TestTable") class TestTableAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 5cddb08f9..0361da5eb 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ 
b/tests/unit/data/_async/test_mutations_batcher.py @@ -31,9 +31,8 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_mutations_batcher" -@CrossSync.export_sync( - sync_name="Test_FlowControl" -) + +@CrossSync.export_sync(sync_name="Test_FlowControl") class Test_FlowControl: @staticmethod @CrossSync.convert @@ -303,9 +302,7 @@ async def test_add_to_flow_oversize(self): assert len(count_results) == 1 -@CrossSync.export_sync( - sync_name="TestMutationsBatcher" -) +@CrossSync.export_sync(sync_name="TestMutationsBatcher") class TestMutationsBatcherAsync: @CrossSync.convert def _get_target_class(self): @@ -1226,6 +1223,7 @@ async def test_large_batch_write(self): Test that a large batch of mutations can be written """ import math + num_mutations = 10_000 flush_limit = 1000 mutations = [self._make_mutation(count=1, size=1)] * num_mutations diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 9032b1a34..bf5af9786 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -32,6 +32,7 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_read_rows_acceptance" + @CrossSync.export_sync( sync_name="TestReadRowsAcceptance", ) diff --git a/tests/unit/data/execute_query/_async/test_query_iterator.py b/tests/unit/data/execute_query/_async/test_query_iterator.py index 2990eac1a..d24c47466 100644 --- a/tests/unit/data/execute_query/_async/test_query_iterator.py +++ b/tests/unit/data/execute_query/_async/test_query_iterator.py @@ -15,9 +15,6 @@ import pytest import concurrent.futures -from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( - ExecuteQueryIteratorAsync, -) from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes @@ -32,9 +29,8 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data.execute_query._sync.test_query_iterator" -@CrossSync.export_sync( - sync_name="MockIterator" -) + +@CrossSync.export_sync(sync_name="MockIterator") class MockIterator: def __init__(self, values, delay=None): self._values = values @@ -56,9 +52,7 @@ async def __anext__(self): return value -@CrossSync.export_sync( - sync_name="TestQueryIterator" -) +@CrossSync.export_sync(sync_name="TestQueryIterator") class TestQueryIteratorAsync: @staticmethod def _target_class(): From 8b13583dd3cba270454fcbf082df8a21096d46bf Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 10 Sep 2024 18:13:07 -0700 Subject: [PATCH 271/360] regenerated files --- .../cloud/bigtable/data/_sync/_mutate_rows.py | 18 ++++++----------- .../cloud/bigtable/data/_sync/_read_rows.py | 8 +++----- google/cloud/bigtable/data/_sync/client.py | 20 ++++++++----------- .../bigtable/data/_sync/mutations_batcher.py | 12 +++-------- .../_sync/execute_query_iterator.py | 8 ++++---- tests/system/data/test_system.py | 2 ++ tests/unit/data/_sync/test__mutate_rows.py | 1 + tests/unit/data/_sync/test__read_rows.py | 1 + tests/unit/data/_sync/test_client.py | 11 ++++------ .../unit/data/_sync/test_mutations_batcher.py | 1 + .../data/_sync/test_read_rows_acceptance.py | 2 ++ .../_sync/test_query_iterator.py | 1 + 12 files changed, 36 insertions(+), 49 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index f36557743..d65d6db57 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ 
-28,20 +28,14 @@ if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry - - if CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable_v2.services.bigtable.async_client import ( - BigtableAsyncClient as GapicClientType, - ) - from google.cloud.bigtable.data._async.client import TableAsync as TableType - else: - from google.cloud.bigtable_v2.services.bigtable.client import ( - BigtableClient as GapicClientType, - ) - from google.cloud.bigtable.data._sync.client import Table as TableType + from google.cloud.bigtable_v2.services.bigtable.client import ( + BigtableClient as GapicClientType, + ) + from google.cloud.bigtable.data._sync.client import Table as TableType +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync._mutate_rows" -class _MutateRowsOperation: +class MutateRowsOperation: """ MutateRowsOperation manages the logic of sending a set of row mutations, and retrying on failed entries. It manages this using the _run_attempt diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index b5fa35479..373ec2884 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -32,13 +32,11 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync if TYPE_CHECKING: - if CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable.data._async.client import TableAsync as TableType - else: - from google.cloud.bigtable.data._sync.client import Table as TableType + from google.cloud.bigtable.data._sync.client import Table as TableType +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync._read_rows" -class _ReadRowsOperation: +class ReadRowsOperation: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream into a stream of Row objects. 
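The commits above move the sync output location out of the class decorator: each annotated async module now declares a module-level __CROSS_SYNC_OUTPUT__ constant, and CrossSync.export_sync takes only an optional sync_name. A minimal sketch of a module annotated under this convention is shown below; the module path and class names are placeholders, not part of the library.

    from google.cloud.bigtable.data._sync.cross_sync import CrossSync

    # Tells .cross_sync/generate.py where to write the generated sync file.
    # Dots become path separators, so this sketch would be written out to
    # example/_sync/my_module.py.
    __CROSS_SYNC_OUTPUT__ = "example._sync.my_module"


    @CrossSync.export_sync(sync_name="MyClass")
    class MyAsyncClass:
        @CrossSync.convert
        async def my_method(self):
            # rm_aio marks an await that the generator strips from the sync surface
            CrossSync.rm_aio(await CrossSync.yield_to_event_loop())

Keeping the target path in a single module-level constant means a file either opts in to generation as a whole or is skipped entirely, which is how generate.py now decides whether to visit a file at all.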
diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index bee3d5c7e..c76afcecf 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -64,22 +64,18 @@ from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter from google.cloud.bigtable.data.row_filters import RowFilterChain from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( + PooledBigtableGrpcTransport as PooledTransportType, +) +from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher +from google.cloud.bigtable.data.execute_query._sync.execute_query_iterator import ( + ExecuteQueryIterator, +) -if CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( - PooledBigtableGrpcAsyncIOTransport as PooledTransportType, - ) -else: - from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledBigtableGrpcTransport as PooledTransportType, - ) - from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher - from google.cloud.bigtable.data.execute_query._sync.execute_query_iterator import ( - ExecuteQueryIterator, - ) if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync.client" @CrossSync._Sync_Impl.add_mapping_decorator("DataClient") diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 9b9c750da..65e5c9d52 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -13,9 +13,6 @@ # limitations under the License. # # This file is automatically generated by CrossSync. Do not edit manually. 
- -# mypy: disable-error-code="unreachable" - from __future__ import annotations from typing import Sequence, TYPE_CHECKING import atexit @@ -34,11 +31,8 @@ if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry - - if CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable.data._async.client import TableAsync as TableType - else: - from google.cloud.bigtable.data._sync.client import Table as TableType + from google.cloud.bigtable.data._sync.client import Table as TableType +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync.mutations_batcher" @CrossSync._Sync_Impl.add_mapping_decorator("_FlowControl") @@ -223,7 +217,7 @@ def __init__( else None ) self._sync_flush_executor = ( - concurrent.futures.ThreadPoolExecutor(max_workers=1) + concurrent.futures.ThreadPoolExecutor(max_workers=4) if not CrossSync._Sync_Impl.is_async else None ) diff --git a/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py index 691675c53..66ac80a9e 100644 --- a/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py @@ -34,10 +34,10 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync if TYPE_CHECKING: - if CrossSync._Sync_Impl.is_async: - from google.cloud.bigtable.data import BigtableDataClientAsync as DataClientType - else: - from google.cloud.bigtable.data import BigtableDataClient as DataClientType + from google.cloud.bigtable.data import BigtableDataClient as DataClientType +__CROSS_SYNC_OUTPUT__ = ( + "google.cloud.bigtable.data.execute_query._sync.execute_query_iterator" +) class ExecuteQueryIterator: diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py index 30dc052c9..43f284128 100644 --- a/tests/system/data/test_system.py +++ b/tests/system/data/test_system.py @@ -23,6 +23,8 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync from . 
import TEST_FAMILY, TEST_FAMILY_2 +__CROSS_SYNC_OUTPUT__ = "tests.system.data.test_system" + @CrossSync._Sync_Impl.add_mapping_decorator("TempRowBuilder") class TempRowBuilder: diff --git a/tests/unit/data/_sync/test__mutate_rows.py b/tests/unit/data/_sync/test__mutate_rows.py index 73c714246..ed2ec4683 100644 --- a/tests/unit/data/_sync/test__mutate_rows.py +++ b/tests/unit/data/_sync/test__mutate_rows.py @@ -24,6 +24,7 @@ from unittest import mock except ImportError: import mock +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test__mutate_rows" class TestMutateRowsOperation: diff --git a/tests/unit/data/_sync/test__read_rows.py b/tests/unit/data/_sync/test__read_rows.py index a71b1bf2b..73b34c631 100644 --- a/tests/unit/data/_sync/test__read_rows.py +++ b/tests/unit/data/_sync/test__read_rows.py @@ -20,6 +20,7 @@ from unittest import mock except ImportError: import mock +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test__read_rows" class TestReadRowsOperation: diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 570786796..24a3a1cfe 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -28,20 +28,17 @@ from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule +from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse from google.cloud.bigtable.data._sync.cross_sync import CrossSync try: from unittest import mock except ImportError: import mock -if CrossSync._Sync_Impl.is_async: - from google.api_core import grpc_helpers_async +from google.api_core import grpc_helpers - CrossSync._Sync_Impl.add_mapping("grpc_helpers", grpc_helpers_async) -else: - from google.api_core import grpc_helpers - - CrossSync._Sync_Impl.add_mapping("grpc_helpers", grpc_helpers) +CrossSync.add_mapping("grpc_helpers", grpc_helpers) +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_client" @CrossSync._Sync_Impl.add_mapping_decorator("TestBigtableDataClient") diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py index 98d23da85..0b23a1ac0 100644 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -26,6 +26,7 @@ from unittest import mock except ImportError: import mock +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_mutations_batcher" class Test_FlowControl: diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync/test_read_rows_acceptance.py index ccb4f42e0..ce0b544f6 100644 --- a/tests/unit/data/_sync/test_read_rows_acceptance.py +++ b/tests/unit/data/_sync/test_read_rows_acceptance.py @@ -25,6 +25,8 @@ from ...v2_client.test_row_merger import ReadRowsTest, TestFile from google.cloud.bigtable.data._sync.cross_sync import CrossSync +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_read_rows_acceptance" + class TestReadRowsAcceptance: diff --git a/tests/unit/data/execute_query/_sync/test_query_iterator.py b/tests/unit/data/execute_query/_sync/test_query_iterator.py index 8e52a1d76..27640447e 100644 --- a/tests/unit/data/execute_query/_sync/test_query_iterator.py +++ b/tests/unit/data/execute_query/_sync/test_query_iterator.py @@ -23,6 +23,7 @@ from unittest import mock except ImportError: import mock +__CROSS_SYNC_OUTPUT__ = "tests.unit.data.execute_query._sync.test_query_iterator" class MockIterator: From 
7eb6ac4b36a8113706843808ec55d9d27b117388 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 11 Sep 2024 11:30:34 -0700 Subject: [PATCH 272/360] use consistent black version --- noxfile.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/noxfile.py b/noxfile.py index ce6ddcde2..3d2fea000 100644 --- a/noxfile.py +++ b/noxfile.py @@ -41,8 +41,8 @@ "pytest", "pytest-cov", "pytest-asyncio", - "black", "autoflake", + BLACK_VERSION, ] UNIT_TEST_EXTERNAL_DEPENDENCIES: List[str] = [] UNIT_TEST_LOCAL_DEPENDENCIES: List[str] = [] @@ -58,7 +58,7 @@ ] SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ "pytest-asyncio==0.21.2", - "black==23.7.0", + BLACK_VERSION, "pyyaml==6.0.2", ] SYSTEM_TEST_LOCAL_DEPENDENCIES: List[str] = [] @@ -555,10 +555,12 @@ def prerelease_deps(session, protobuf_implementation): }, ) -@nox.session(python="3.11") + +@nox.session(python="3.10") def generate_sync(session): """ Re-generate sync files for the library from CrossSync-annotated async source """ - session.install("black", "autoflake") + session.install(BLACK_VERSION) + session.install("autoflake") session.run("python", ".cross_sync/generate.py", ".") From 0efb7f659c154f1ec34983e1b05f75f8f09accc9 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Wed, 11 Sep 2024 18:33:51 +0000 Subject: [PATCH 273/360] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot?= =?UTF-8?q?=20post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- .../cloud/bigtable/data/_sync/_mutate_rows.py | 3 +- google/cloud/bigtable/data/_sync/client.py | 70 ++++++++----------- .../bigtable/data/_sync/mutations_batcher.py | 19 +++-- .../_sync/execute_query_iterator.py | 1 - tests/system/data/test_system.py | 29 ++++---- tests/unit/data/_sync/test__mutate_rows.py | 3 +- tests/unit/data/_sync/test__read_rows.py | 8 +-- tests/unit/data/_sync/test_client.py | 51 ++++++-------- .../unit/data/_sync/test_mutations_batcher.py | 16 ++--- .../data/_sync/test_read_rows_acceptance.py | 6 -- .../_sync/test_query_iterator.py | 2 - tests/unit/data/test_sync_up_to_date.py | 4 +- 12 files changed, 87 insertions(+), 125 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index d65d6db57..ff94691a7 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -136,7 +136,8 @@ def _run_attempt(self): GoogleAPICallError: if the gapic rpc fails""" request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] active_request_indices = { - req_idx: orig_idx for req_idx, orig_idx in enumerate(self.remaining_indices) + req_idx: orig_idx + for (req_idx, orig_idx) in enumerate(self.remaining_indices) } self.remaining_indices = [] if not request_entries: diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index c76afcecf..cc4c03e5f 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -80,16 +80,15 @@ @CrossSync._Sync_Impl.add_mapping_decorator("DataClient") class BigtableDataClient(ClientWithProject): - def __init__( self, *, project: str | None = None, pool_size: int = 3, credentials: google.auth.credentials.Credentials | None = None, - client_options: ( - dict[str, Any] | "google.api_core.client_options.ClientOptions" | None - ) = None, + client_options: dict[str, Any] + 
| "google.api_core.client_options.ClientOptions" + | None = None, ): """Create a client instance for the Bigtable Data API @@ -243,7 +242,7 @@ def _ping_and_warm_instances( ], wait_for_ready=True, ) - for instance_name, table_name, app_profile_id in instance_list + for (instance_name, table_name, app_profile_id) in instance_list ] result_list = CrossSync._Sync_Impl.gather_partials( partial_list, return_exceptions=True, sync_executor=self._executor @@ -618,9 +617,8 @@ def read_rows_stream( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: ( - Sequence[type[Exception]] | TABLE_DEFAULT - ) = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> Iterable[Row]: """Read a set of rows from the table, based on the specified query. Returns an iterator to asynchronously stream back row data. @@ -648,7 +646,7 @@ def read_rows_stream( from any retries that failed google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error """ - operation_timeout, attempt_timeout = _get_timeouts( + (operation_timeout, attempt_timeout) = _get_timeouts( operation_timeout, attempt_timeout, self ) retryable_excs = _get_retryable_errors(retryable_errors, self) @@ -667,9 +665,8 @@ def read_rows( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: ( - Sequence[type[Exception]] | TABLE_DEFAULT - ) = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: """Read a set of rows from the table, based on the specified query. Retruns results as a list of Row objects when the request is complete. @@ -715,9 +712,8 @@ def read_row( row_filter: RowFilter | None = None, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: ( - Sequence[type[Exception]] | TABLE_DEFAULT - ) = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> Row | None: """Read a single row from the table, based on the specified key. @@ -763,9 +759,8 @@ def read_rows_sharded( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: ( - Sequence[type[Exception]] | TABLE_DEFAULT - ) = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: """Runs a sharded query in parallel, then return the results in a single list. Results will be returned in the order of the input queries. 
@@ -797,7 +792,7 @@ def read_rows_sharded( ValueError: if the query_list is empty""" if not sharded_query: raise ValueError("empty sharded_query") - operation_timeout, attempt_timeout = _get_timeouts( + (operation_timeout, attempt_timeout) = _get_timeouts( operation_timeout, attempt_timeout, self ) rpc_timeout_generator = _attempt_timeout_generator( @@ -840,7 +835,7 @@ def read_rows_with_semaphore(query): raise ShardedReadRowsExceptionGroup( [ FailedQueryShardError(idx, sharded_query[idx], e) - for idx, e in error_dict.items() + for (idx, e) in error_dict.items() ], results_list, len(sharded_query), @@ -853,9 +848,8 @@ def row_exists( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, - retryable_errors: ( - Sequence[type[Exception]] | TABLE_DEFAULT - ) = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> bool: """Return a boolean indicating whether the specified row exists in the table. uses the filters: chain(limit cells per row = 1, strip value) @@ -899,9 +893,8 @@ def sample_row_keys( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - retryable_errors: ( - Sequence[type[Exception]] | TABLE_DEFAULT - ) = TABLE_DEFAULT.DEFAULT, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> RowKeySamples: """Return a set of RowKeySamples that delimit contiguous sections of the table of approximately equal size @@ -932,7 +925,7 @@ def sample_row_keys( from any retries that failed google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error """ - operation_timeout, attempt_timeout = _get_timeouts( + (operation_timeout, attempt_timeout) = _get_timeouts( operation_timeout, attempt_timeout, self ) attempt_timeout_gen = _attempt_timeout_generator( @@ -973,9 +966,8 @@ def mutations_batcher( flow_control_max_bytes: int = 100 * _MB_SIZE, batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - batch_retryable_errors: ( - Sequence[type[Exception]] | TABLE_DEFAULT - ) = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, ) -> MutationsBatcher: """Returns a new mutations batcher instance. @@ -1019,9 +1011,8 @@ def mutate_row( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, - retryable_errors: ( - Sequence[type[Exception]] | TABLE_DEFAULT - ) = TABLE_DEFAULT.DEFAULT, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ): """Mutates a row atomically. @@ -1052,7 +1043,7 @@ def mutate_row( google.api_core.exceptions.GoogleAPIError: raised on non-idempotent operations that cannot be safely retried. 
ValueError: if invalid arguments are provided""" - operation_timeout, attempt_timeout = _get_timeouts( + (operation_timeout, attempt_timeout) = _get_timeouts( operation_timeout, attempt_timeout, self ) if not mutations: @@ -1091,9 +1082,8 @@ def bulk_mutate_rows( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - retryable_errors: ( - Sequence[type[Exception]] | TABLE_DEFAULT - ) = TABLE_DEFAULT.MUTATE_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, ): """Applies mutations for multiple rows in a single batched request. @@ -1124,7 +1114,7 @@ def bulk_mutate_rows( MutationsExceptionGroup: if one or more mutations fails Contains details about any failed entries in .exceptions ValueError: if invalid arguments are provided""" - operation_timeout, attempt_timeout = _get_timeouts( + (operation_timeout, attempt_timeout) = _get_timeouts( operation_timeout, attempt_timeout, self ) retryable_excs = _get_retryable_errors(retryable_errors, self) @@ -1177,7 +1167,7 @@ def check_and_mutate_row( bool indicating whether the predicate was true or false Raises: google.api_core.exceptions.GoogleAPIError: exceptions from grpc call""" - operation_timeout, _ = _get_timeouts(operation_timeout, None, self) + (operation_timeout, _) = _get_timeouts(operation_timeout, None, self) if true_case_mutations is not None and ( not isinstance(true_case_mutations, list) ): @@ -1232,7 +1222,7 @@ def read_modify_write_row( Raises: google.api_core.exceptions.GoogleAPIError: exceptions from grpc call ValueError: if invalid arguments are provided""" - operation_timeout, _ = _get_timeouts(operation_timeout, None, self) + (operation_timeout, _) = _get_timeouts(operation_timeout, None, self) if operation_timeout <= 0: raise ValueError("operation_timeout must be greater than 0") if rules is not None and (not isinstance(rules, list)): diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 65e5c9d52..924418201 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -188,11 +188,10 @@ def __init__( flow_control_max_bytes: int = 100 * _MB_SIZE, batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - batch_retryable_errors: ( - Sequence[type[Exception]] | TABLE_DEFAULT - ) = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, ): - self._operation_timeout, self._attempt_timeout = _get_timeouts( + (self._operation_timeout, self._attempt_timeout) = _get_timeouts( batch_operation_timeout, batch_attempt_timeout, table ) self._retryable_errors: list[type[Exception]] = _get_retryable_errors( @@ -201,7 +200,7 @@ def __init__( self._closed = CrossSync._Sync_Impl.Event() self._table = table self._staged_entries: list[RowMutationEntry] = [] - self._staged_count, self._staged_bytes = (0, 0) + (self._staged_count, self._staged_bytes) = (0, 0) self._flow_control = CrossSync._Sync_Impl._FlowControl( flow_control_max_mutation_count, flow_control_max_bytes ) @@ -282,8 +281,8 @@ def _schedule_flush(self) -> CrossSync._Sync_Impl.Future[None] | None: Future[None] | None: future representing the background task, if started""" if self._staged_entries: - entries, self._staged_entries = 
(self._staged_entries, []) - self._staged_count, self._staged_bytes = (0, 0) + (entries, self._staged_entries) = (self._staged_entries, []) + (self._staged_count, self._staged_bytes) = (0, 0) new_task = CrossSync._Sync_Impl.create_task( self._flush_internal, entries, sync_executor=self._sync_flush_executor ) @@ -362,14 +361,14 @@ def _raise_exceptions(self): Raises: MutationsExceptionGroup: exception group with all unreported exceptions""" if self._oldest_exceptions or self._newest_exceptions: - oldest, self._oldest_exceptions = (self._oldest_exceptions, []) + (oldest, self._oldest_exceptions) = (self._oldest_exceptions, []) newest = list(self._newest_exceptions) self._newest_exceptions.clear() - entry_count, self._entries_processed_since_last_raise = ( + (entry_count, self._entries_processed_since_last_raise) = ( self._entries_processed_since_last_raise, 0, ) - exc_count, self._exceptions_since_last_raise = ( + (exc_count, self._exceptions_since_last_raise) = ( self._exceptions_since_last_raise, 0, ) diff --git a/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py index 66ac80a9e..7523e11d6 100644 --- a/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py @@ -41,7 +41,6 @@ class ExecuteQueryIterator: - def __init__( self, client: DataClientType, diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py index 43f284128..0c77623cf 100644 --- a/tests/system/data/test_system.py +++ b/tests/system/data/test_system.py @@ -72,7 +72,6 @@ def delete_rows(self): class TestSystem: - @pytest.fixture(scope="session") def client(self): project = os.getenv("GOOGLE_CLOUD_PROJECT") or None @@ -175,7 +174,7 @@ def test_mutation_set_cell(self, table, temp_rows): """Ensure cells can be set properly""" row_key = b"bulk_mutate" new_value = uuid.uuid4().hex.encode() - row_key, mutation = self._create_row_and_mutation( + (row_key, mutation) = self._create_row_and_mutation( table, temp_rows, new_value=new_value ) table.mutate_row(row_key, mutation) @@ -208,7 +207,7 @@ def test_bulk_mutations_set_cell(self, client, table, temp_rows): from google.cloud.bigtable.data.mutations import RowMutationEntry new_value = uuid.uuid4().hex.encode() - row_key, mutation = self._create_row_and_mutation( + (row_key, mutation) = self._create_row_and_mutation( table, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) @@ -243,11 +242,11 @@ def test_mutations_batcher_context_manager(self, client, table, temp_rows): """test batcher with context manager. 
Should flush on exit""" from google.cloud.bigtable.data.mutations import RowMutationEntry - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = self._create_row_and_mutation( + (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] + (row_key, mutation) = self._create_row_and_mutation( table, temp_rows, new_value=new_value ) - row_key2, mutation2 = self._create_row_and_mutation( + (row_key2, mutation2) = self._create_row_and_mutation( table, temp_rows, new_value=new_value2 ) bulk_mutation = RowMutationEntry(row_key, [mutation]) @@ -268,7 +267,7 @@ def test_mutations_batcher_timer_flush(self, client, table, temp_rows): from google.cloud.bigtable.data.mutations import RowMutationEntry new_value = uuid.uuid4().hex.encode() - row_key, mutation = self._create_row_and_mutation( + (row_key, mutation) = self._create_row_and_mutation( table, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) @@ -290,12 +289,12 @@ def test_mutations_batcher_count_flush(self, client, table, temp_rows): """batch should flush after flush_limit_mutation_count mutations""" from google.cloud.bigtable.data.mutations import RowMutationEntry - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = self._create_row_and_mutation( + (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] + (row_key, mutation) = self._create_row_and_mutation( table, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = self._create_row_and_mutation( + (row_key2, mutation2) = self._create_row_and_mutation( table, temp_rows, new_value=new_value2 ) bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) @@ -322,12 +321,12 @@ def test_mutations_batcher_bytes_flush(self, client, table, temp_rows): """batch should flush after flush_limit_bytes bytes""" from google.cloud.bigtable.data.mutations import RowMutationEntry - new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)] - row_key, mutation = self._create_row_and_mutation( + (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)] + (row_key, mutation) = self._create_row_and_mutation( table, temp_rows, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = self._create_row_and_mutation( + (row_key2, mutation2) = self._create_row_and_mutation( table, temp_rows, new_value=new_value2 ) bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) @@ -353,11 +352,11 @@ def test_mutations_batcher_no_flush(self, client, table, temp_rows): new_value = uuid.uuid4().hex.encode() start_value = b"unchanged" - row_key, mutation = self._create_row_and_mutation( + (row_key, mutation) = self._create_row_and_mutation( table, temp_rows, start_value=start_value, new_value=new_value ) bulk_mutation = RowMutationEntry(row_key, [mutation]) - row_key2, mutation2 = self._create_row_and_mutation( + (row_key2, mutation2) = self._create_row_and_mutation( table, temp_rows, start_value=start_value, new_value=new_value ) bulk_mutation2 = RowMutationEntry(row_key2, [mutation2]) diff --git a/tests/unit/data/_sync/test__mutate_rows.py b/tests/unit/data/_sync/test__mutate_rows.py index ed2ec4683..33510630c 100644 --- a/tests/unit/data/_sync/test__mutate_rows.py +++ b/tests/unit/data/_sync/test__mutate_rows.py @@ -28,7 +28,6 @@ class TestMutateRowsOperation: - def _target_class(self): return CrossSync._Sync_Impl._MutateRowsOperation @@ -263,7 +262,7 @@ def 
test_run_attempt_single_entry_success(self): instance._run_attempt() assert len(instance.remaining_indices) == 0 assert mock_gapic_fn.call_count == 1 - _, kwargs = mock_gapic_fn.call_args + (_, kwargs) = mock_gapic_fn.call_args assert kwargs["timeout"] == expected_timeout assert kwargs["entries"] == [mutation._to_pb()] diff --git a/tests/unit/data/_sync/test__read_rows.py b/tests/unit/data/_sync/test__read_rows.py index 73b34c631..84dd95d96 100644 --- a/tests/unit/data/_sync/test__read_rows.py +++ b/tests/unit/data/_sync/test__read_rows.py @@ -171,11 +171,12 @@ def test_revise_request_rowset_ranges( next_key = (last_key + "a").encode("utf-8") last_key = last_key.encode("utf-8") in_ranges = [ - RowRangePB(**{k: v.encode("utf-8") for k, v in r.items()}) + RowRangePB(**{k: v.encode("utf-8") for (k, v) in r.items()}) for r in in_ranges ] expected = [ - RowRangePB(**{k: v.encode("utf-8") for k, v in r.items()}) for r in expected + RowRangePB(**{k: v.encode("utf-8") for (k, v) in r.items()}) + for r in expected ] if with_key: row_keys = [next_key] @@ -238,7 +239,6 @@ def test_revise_limit(self, start_limit, emit_num, expected_limit): from google.cloud.bigtable_v2.types import ReadRowsResponse def awaitable_stream(): - def mock_stream(): for i in range(emit_num): yield ReadRowsResponse( @@ -274,7 +274,6 @@ def test_revise_limit_over_limit(self, start_limit, emit_num): from google.cloud.bigtable.data.exceptions import InvalidChunk def awaitable_stream(): - def mock_stream(): for i in range(emit_num): yield ReadRowsResponse( @@ -333,7 +332,6 @@ def test_retryable_ignore_repeated_rows(self): row_key = b"duplicate" def mock_awaitable_stream(): - def mock_stream(): while True: yield ReadRowsResponse( diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 24a3a1cfe..d2936b714 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -43,7 +43,6 @@ @CrossSync._Sync_Impl.add_mapping_decorator("TestBigtableDataClient") class TestBigtableDataClient: - @staticmethod def _get_target_class(): return CrossSync._Sync_Impl.DataClient @@ -314,9 +313,11 @@ def test__ping_and_warm_instances(self): gather.assert_awaited_once() grpc_call_args = channel.unary_unary().call_args_list for idx, (_, kwargs) in enumerate(grpc_call_args): - expected_instance, expected_table, expected_app_profile = ( - client_mock._active_instances[idx] - ) + ( + expected_instance, + expected_table, + expected_app_profile, + ) = client_mock._active_instances[idx] request = kwargs["request"] assert request["name"] == expected_instance assert request["app_profile_id"] == expected_app_profile @@ -405,9 +406,9 @@ def test__manage_channel_ping_and_warm(self): ) with mock.patch.object(*sleep_tuple): client_mock.transport.replace_channel.side_effect = asyncio.CancelledError - ping_and_warm = client_mock._ping_and_warm_instances = ( - CrossSync._Sync_Impl.Mock() - ) + ping_and_warm = ( + client_mock._ping_and_warm_instances + ) = CrossSync._Sync_Impl.Mock() try: channel_idx = 1 self._get_target_class()._manage_channel(client_mock, channel_idx, 10) @@ -552,7 +553,7 @@ def test__manage_channel_refresh(self, num_cycles): assert create_channel.call_count == num_cycles assert replace_channel.call_count == num_cycles for call in replace_channel.call_args_list: - args, kwargs = call + (args, kwargs) = call assert args[0] == channel_idx assert kwargs["grace"] == expected_grace assert kwargs["new_channel"] == new_channel @@ -958,7 +959,6 @@ def test_context_manager(self): 
@CrossSync._Sync_Impl.add_mapping_decorator("TestTable") class TestTable: - def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -1272,7 +1272,6 @@ def _make_gapic_stream( from google.cloud.bigtable_v2 import ReadRowsResponse class mock_stream: - def __init__(self, chunk_list, sleep_time): self.chunk_list = chunk_list self.idx = -1 @@ -1563,7 +1562,7 @@ def test_read_row(self): ) assert row == expected_result assert read_rows.call_count == 1 - args, kwargs = read_rows.call_args_list[0] + (args, kwargs) = read_rows.call_args_list[0] assert kwargs["operation_timeout"] == expected_op_timeout assert kwargs["attempt_timeout"] == expected_req_timeout assert len(args) == 1 @@ -1594,7 +1593,7 @@ def test_read_row_w_filter(self): ) assert row == expected_result assert read_rows.call_count == 1 - args, kwargs = read_rows.call_args_list[0] + (args, kwargs) = read_rows.call_args_list[0] assert kwargs["operation_timeout"] == expected_op_timeout assert kwargs["attempt_timeout"] == expected_req_timeout assert len(args) == 1 @@ -1621,7 +1620,7 @@ def test_read_row_no_response(self): ) assert result is None assert read_rows.call_count == 1 - args, kwargs = read_rows.call_args_list[0] + (args, kwargs) = read_rows.call_args_list[0] assert kwargs["operation_timeout"] == expected_op_timeout assert kwargs["attempt_timeout"] == expected_req_timeout assert isinstance(args[0], ReadRowsQuery) @@ -1650,7 +1649,7 @@ def test_row_exists(self, return_value, expected_result): ) assert expected_result == result assert read_rows.call_count == 1 - args, kwargs = read_rows.call_args_list[0] + (args, kwargs) = read_rows.call_args_list[0] assert kwargs["operation_timeout"] == expected_op_timeout assert kwargs["attempt_timeout"] == expected_req_timeout assert isinstance(args[0], ReadRowsQuery) @@ -1670,7 +1669,6 @@ def test_row_exists(self, return_value, expected_result): class TestReadRowsSharded: - def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -1783,7 +1781,7 @@ def mock_call(*args, **kwargs): assert read_rows.call_count == num_queries rpc_start_list = [ starting_timeout - kwargs["operation_timeout"] - for _, kwargs in read_rows.call_args_list + for (_, kwargs) in read_rows.call_args_list ] eps = 0.01 assert all( @@ -1856,7 +1854,6 @@ def mock_call(*args, **kwargs): class TestSampleRowKeys: - def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -1914,7 +1911,7 @@ def test_sample_row_keys_default_timeout(self): ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) result = table.sample_row_keys() - _, kwargs = sample_row_keys.call_args + (_, kwargs) = sample_row_keys.call_args assert abs(kwargs["timeout"] - expected_timeout) < 0.1 assert result == [] assert kwargs["retry"] is None @@ -1936,7 +1933,7 @@ def test_sample_row_keys_gapic_params(self): ) as sample_row_keys: sample_row_keys.return_value = self._make_gapic_stream([]) table.sample_row_keys(attempt_timeout=expected_timeout) - args, kwargs = sample_row_keys.call_args + (args, kwargs) = sample_row_keys.call_args assert len(args) == 0 assert len(kwargs) == 5 assert kwargs["timeout"] == expected_timeout @@ -1995,7 +1992,6 @@ def test_sample_row_keys_non_retryable_errors(self, non_retryable_exception): class TestMutateRow: - def _make_client(self, *args, **kwargs): return 
CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -2148,7 +2144,6 @@ def test_mutate_row_no_mutations(self, mutations): class TestBulkMutateRows: - def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -2485,7 +2480,6 @@ def test_bulk_mutate_error_recovery(self): class TestCheckAndMutateRow: - def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -2630,7 +2624,6 @@ def test_check_and_mutate_mutations_parsing(self): class TestReadModifyWriteRow: - def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) @@ -2756,9 +2749,7 @@ def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) def _make_gapic_stream(self, sample_list: list["ExecuteQueryResponse" | Exception]): - class MockStream: - def __init__(self, sample_list): self.sample_list = sample_list @@ -2791,7 +2782,7 @@ def resonse_with_metadata(self): "proto_schema": { "columns": [ {"name": name, "type_": {_type: {}}} - for name, _type in schema.items() + for (name, _type) in schema.items() ] } } @@ -2813,11 +2804,9 @@ def resonse_with_result(self, *args, resume_token=None): else: pb_value = PBValue( { - ( - "int_value" - if isinstance(column_value, int) - else "string_value" - ): column_value + "int_value" + if isinstance(column_value, int) + else "string_value": column_value } ) values.append(pb_value) diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py index 0b23a1ac0..2064df07c 100644 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -30,7 +30,6 @@ class Test_FlowControl: - @staticmethod def _target_class(): return CrossSync._Sync_Impl._FlowControl @@ -260,7 +259,6 @@ def test_add_to_flow_oversize(self): class TestMutationsBatcher: - def _get_target_class(self): return CrossSync._Sync_Impl.MutationsBatcher @@ -454,9 +452,9 @@ def test__start_flush_timer_w_empty_input(self, input_val): ) as flush_mock: with self._make_one() as instance: if CrossSync._Sync_Impl.is_async: - sleep_obj, sleep_method = (asyncio, "wait_for") + (sleep_obj, sleep_method) = (asyncio, "wait_for") else: - sleep_obj, sleep_method = (instance._closed, "wait") + (sleep_obj, sleep_method) = (instance._closed, "wait") with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: result = instance._timer_routine(input_val) assert sleep_mock.call_count == 0 @@ -473,9 +471,9 @@ def test__start_flush_timer_call_when_closed(self): instance.close() flush_mock.reset_mock() if CrossSync._Sync_Impl.is_async: - sleep_obj, sleep_method = (asyncio, "wait_for") + (sleep_obj, sleep_method) = (asyncio, "wait_for") else: - sleep_obj, sleep_method = (instance._closed, "wait") + (sleep_obj, sleep_method) = (instance._closed, "wait") with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: instance._timer_routine(10) assert sleep_mock.call_count == 0 @@ -838,7 +836,7 @@ def test__execute_mutate_rows(self): batch = [self._make_mutation()] result = instance._execute_mutate_rows(batch) assert start_operation.call_count == 1 - args, kwargs = mutate_rows.call_args + (args, kwargs) = mutate_rows.call_args assert args[0] == table.client._gapic_client assert args[1] == table assert args[2] == batch @@ -887,7 +885,7 @@ def test__raise_exceptions(self): assert 
list(exc.exceptions) == expected_exceptions assert str(expected_total) in str(exc) assert instance._entries_processed_since_last_raise == 0 - instance._oldest_exceptions, instance._newest_exceptions = ([], []) + (instance._oldest_exceptions, instance._newest_exceptions) = ([], []) instance._raise_exceptions() def test___enter__(self): @@ -929,7 +927,7 @@ def test_close_w_exceptions(self): assert list(exc.exceptions) == expected_exceptions assert str(expected_total) in str(exc) assert instance._entries_processed_since_last_raise == 0 - instance._oldest_exceptions, instance._newest_exceptions = ([], []) + (instance._oldest_exceptions, instance._newest_exceptions) = ([], []) def test__on_exit(self, recwarn): """Should raise warnings if unflushed mutations exist""" diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync/test_read_rows_acceptance.py index ce0b544f6..9716391eb 100644 --- a/tests/unit/data/_sync/test_read_rows_acceptance.py +++ b/tests/unit/data/_sync/test_read_rows_acceptance.py @@ -29,7 +29,6 @@ class TestReadRowsAcceptance: - @staticmethod def _get_operation_class(): return CrossSync._Sync_Impl._ReadRowsOperation @@ -67,7 +66,6 @@ def _coro_wrapper(stream): return stream def _process_chunks(self, *chunks): - def _row_stream(): yield ReadRowsResponse(chunks=chunks) @@ -87,7 +85,6 @@ def _row_stream(): "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description ) def test_row_merger_scenario(self, test_case: ReadRowsTest): - def _scenerio_stream(): for chunk in test_case.chunks: yield ReadRowsResponse(chunks=[chunk]) @@ -121,12 +118,10 @@ def _scenerio_stream(): "test_case", parse_readrows_acceptance_tests(), ids=lambda t: t.description ) def test_read_rows_scenario(self, test_case: ReadRowsTest): - def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): from google.cloud.bigtable_v2 import ReadRowsResponse class mock_stream: - def __init__(self, chunk_list): self.chunk_list = chunk_list self.idx = -1 @@ -182,7 +177,6 @@ def cancel(self): assert actual == expected def test_out_of_order_rows(self): - def _row_stream(): yield ReadRowsResponse(last_scanned_row_key=b"a") diff --git a/tests/unit/data/execute_query/_sync/test_query_iterator.py b/tests/unit/data/execute_query/_sync/test_query_iterator.py index 27640447e..8621d7dd5 100644 --- a/tests/unit/data/execute_query/_sync/test_query_iterator.py +++ b/tests/unit/data/execute_query/_sync/test_query_iterator.py @@ -27,7 +27,6 @@ class MockIterator: - def __init__(self, values, delay=None): self._values = values self.idx = 0 @@ -47,7 +46,6 @@ def __next__(self): class TestQueryIterator: - @staticmethod def _target_class(): return CrossSync._Sync_Impl.ExecuteQueryIterator diff --git a/tests/unit/data/test_sync_up_to_date.py b/tests/unit/data/test_sync_up_to_date.py index cb616f130..e183e2410 100644 --- a/tests/unit/data/test_sync_up_to_date.py +++ b/tests/unit/data/test_sync_up_to_date.py @@ -40,9 +40,7 @@ def test_sync_up_to_date(artifact): new_render = artifact.render() found_render = open(path).read() # compare by content - diff = unified_diff( - found_render.splitlines(), new_render.splitlines(), lineterm="" - ) + diff = unified_diff(found_render.splitlines(), new_render.splitlines(), lineterm="") diff_str = "\n".join(diff) assert not diff_str, f"Found differences:\n{diff_str}" # compare by hash From d4224cc74d4bf4b8377fc596240dc7585d493f60 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 11 Sep 2024 11:53:37 -0700 Subject: [PATCH 274/360] changed black version --- 
noxfile.py | 4 ++-- tests/unit/data/test_sync_up_to_date.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/noxfile.py b/noxfile.py index 3d2fea000..5b3e6d992 100644 --- a/noxfile.py +++ b/noxfile.py @@ -28,7 +28,7 @@ import nox FLAKE8_VERSION = "flake8==6.1.0" -BLACK_VERSION = "black[jupyter]==23.7.0" +BLACK_VERSION = "black[jupyter]==23.3.0" ISORT_VERSION = "isort==5.11.0" LINT_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] @@ -41,8 +41,8 @@ "pytest", "pytest-cov", "pytest-asyncio", - "autoflake", BLACK_VERSION, + "autoflake", ] UNIT_TEST_EXTERNAL_DEPENDENCIES: List[str] = [] UNIT_TEST_LOCAL_DEPENDENCIES: List[str] = [] diff --git a/tests/unit/data/test_sync_up_to_date.py b/tests/unit/data/test_sync_up_to_date.py index e183e2410..d72c70988 100644 --- a/tests/unit/data/test_sync_up_to_date.py +++ b/tests/unit/data/test_sync_up_to_date.py @@ -26,10 +26,12 @@ from generate import convert_files_in_dir # noqa: E402 +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="ast.unparse is only available in 3.9+" +) @pytest.mark.parametrize( "artifact", convert_files_in_dir(repo_root), ids=lambda a: a.file_path ) -@pytest.mark.skipif(sys.version_info < (3, 11), reason="generation uses python3.11") def test_sync_up_to_date(artifact): """ Generate a fresh copy of each cross_sync file, and compare hashes with the existing file. From c8d91c07a445ec2d06ca1bef72e2e8058612105b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 11 Sep 2024 15:43:42 -0700 Subject: [PATCH 275/360] fixed underscores in annotations --- google/cloud/bigtable/data/_async/_mutate_rows.py | 2 +- google/cloud/bigtable/data/_async/_read_rows.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 3d79361bc..f0a7c050b 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -47,7 +47,7 @@ __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync._mutate_rows" -@CrossSync.export_sync("MutateRowsOperation") +@CrossSync.export_sync("_MutateRowsOperation") class _MutateRowsOperationAsync: """ MutateRowsOperation manages the logic of sending a set of row mutations, diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index a2d8195a6..e8219f75c 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -45,7 +45,7 @@ __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync._read_rows" -@CrossSync.export_sync("ReadRowsOperation") +@CrossSync.export_sync("_ReadRowsOperation") class _ReadRowsOperationAsync: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream From ef00397429f24736c4311e70204b62230384d1c7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 11 Sep 2024 15:44:02 -0700 Subject: [PATCH 276/360] fixed string access --- .cross_sync/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 1744e9da0..a6c89887c 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -238,7 +238,7 @@ def get_output_path(node): for target in n.targets: if isinstance(target, ast.Name) and target.id == "__CROSS_SYNC_OUTPUT__": # keep the output path - return n.value.value.replace(".", "/") + ".py" + return n.value.s.replace(".", "/") + ".py" def 
visit_Module(self, node): # look for __CROSS_SYNC_OUTPUT__ Assign statement From ea571bd63e16e84309f1da6fffcd90a496e9e295 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 11:19:28 -0700 Subject: [PATCH 277/360] simplified transformers. Removed CrossSyncMethodHandler --- .cross_sync/generate.py | 6 +- .cross_sync/transformers.py | 104 ++++--- .../data/_sync/cross_sync/_decorators.py | 35 ++- .../test_cases/cross_sync_files.yaml | 255 ++++++++++++++++-- .../test_cases/cross_sync_methods.yaml | 144 ---------- .../system/cross_sync/test_cross_sync_e2e.py | 3 +- 6 files changed, 329 insertions(+), 218 deletions(-) delete mode 100644 tests/system/cross_sync/test_cases/cross_sync_methods.yaml diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index 5c130079f..321ad82fa 100644 --- a/.cross_sync/generate.py +++ b/.cross_sync/generate.py @@ -18,7 +18,7 @@ Entrypoint for initiating an async -> sync conversion using CrossSync Finds all python files rooted in a given directory, and uses -transformers.CrossSyncFileHandler to handle any files marked with +transformers.CrossSyncFileProcessor to handle any files marked with __CROSS_SYNC_OUTPUT__ """ @@ -73,13 +73,13 @@ def render(self, with_black=True, save_to_disk: bool = False) -> str: def convert_files_in_dir(directory: str) -> set[CrossSyncOutputFile]: import glob - from transformers import CrossSyncFileHandler + from transformers import CrossSyncFileProcessor # find all python files in the directory files = glob.glob(directory + "/**/*.py", recursive=True) # keep track of the output files pointed to by the annotated classes artifacts: set[CrossSyncOutputFile] = set() - file_transformer = CrossSyncFileHandler() + file_transformer = CrossSyncFileProcessor() # run each file through ast transformation to find all annotated classes for file_path in files: ast_tree = ast.parse(open(file_path).read()) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index a6c89887c..b4ffe7ba6 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -27,6 +27,7 @@ from __future__ import annotations import ast +import copy import sys # add cross_sync to path @@ -197,55 +198,39 @@ def visit_AsyncFor(self, node): return self.generic_visit(node) -class CrossSyncMethodDecoratorHandler(ast.NodeTransformer): +class CrossSyncFileProcessor(ast.NodeTransformer): """ - Visits each method in a class, and handles any CrossSync decorators found + Visits a file, looking for __CROSS_SYNC_OUTPUT__ annotations + + If found, the file is processed with the following steps: + - Strip out asyncio keywords within CrossSync.rm_aio calls + - transform classes and methods annotated with CrossSync decorators + - classes not marked with @CrossSync.export are discarded in sync version + - statements behind CrossSync.is_async conditional branches are removed + - Replace remaining CrossSync statements with corresponding CrossSync._Sync calls + - save changes in an output file at path specified by __CROSS_SYNC_OUTPUT__ """ + FILE_ANNOTATION = "__CROSS_SYNC_OUTPUT__" - def visit_FunctionDef(self, node): - return self.visit_AsyncFunctionDef(node) - - def visit_AsyncFunctionDef(self, node): - try: - if hasattr(node, "decorator_list"): - found_list, node.decorator_list = node.decorator_list, [] - for decorator in found_list: - try: - handler = AstDecorator.get_for_node(decorator) - node = handler.sync_ast_transform(node, globals()) - if node is None: - return None - # recurse to any nested functions - node = self.generic_visit(node) - 
except ValueError: - # keep unknown decorators - node.decorator_list.append(decorator) - continue - return node - except ValueError as e: - raise ValueError(f"node {node.name} failed") from e - - -class CrossSyncFileHandler(ast.NodeTransformer): - """ - Visit each file, and process CrossSync classes if found - """ - - @staticmethod - def get_output_path(node): + def get_output_path(self, node): for i, n in enumerate(node.body): if isinstance(n, ast.Assign): for target in n.targets: - if isinstance(target, ast.Name) and target.id == "__CROSS_SYNC_OUTPUT__": - # keep the output path - return n.value.s.replace(".", "/") + ".py" + if isinstance(target, ast.Name) and target.id == self.FILE_ANNOTATION: + # return the output path + return n.value.value.replace(".", "/") + ".py" def visit_Module(self, node): # look for __CROSS_SYNC_OUTPUT__ Assign statement output_path = self.get_output_path(node) if output_path: # if found, process the file - return self.generic_visit(node) + converted = self.generic_visit(node) + # strip out CrossSync.rm_aio calls + converted = RmAioFunctions().visit(converted) + # replace CrossSync statements + converted = SymbolReplacer({"CrossSync": "CrossSync._Sync_Impl"}).visit(converted) + return converted else: # not cross_sync file. Return None return None @@ -260,18 +245,55 @@ def visit_ClassDef(self, node): handler = AstDecorator.get_for_node(decorator) if isinstance(handler, ExportSync): # transformation is handled in sync_ast_transform method of the decorator - return handler.sync_ast_transform(node, globals()) + after_export = handler.sync_ast_transform(node, globals()) + return self.generic_visit(after_export) except ValueError: # not cross_sync decorator continue - # cross_sync decorator not found - return node + # cross_sync decorator not found. 
Drop from sync version + return None def visit_If(self, node): """ remove CrossSync.is_async branches from top-level if statements """ if isinstance(node.test, ast.Attribute) and isinstance(node.test.value, ast.Name) and node.test.value.id == "CrossSync" and node.test.attr == "is_async": - return node.orelse + return [self.generic_visit(n) for n in node.orelse] return self.generic_visit(node) + def visit_Assign(self, node): + """ + strip out __CROSS_SYNC_OUTPUT__ assignments + """ + if isinstance(node.targets[0], ast.Name) and node.targets[0].id == self.FILE_ANNOTATION: + return None + return self.generic_visit(node) + + def visit_FunctionDef(self, node): + """ + Visit any sync methods marked with CrossSync decorators + """ + return self.visit_AsyncFunctionDef(node) + + def visit_AsyncFunctionDef(self, node): + """ + Visit and transform any async methods marked with CrossSync decorators + """ + try: + if hasattr(node, "decorator_list"): + found_list, node.decorator_list = node.decorator_list, [] + for decorator in found_list: + try: + handler = AstDecorator.get_for_node(decorator) + node = handler.sync_ast_transform(node, globals()) + if node is None: + return None + # recurse to any nested functions + node = self.generic_visit(node) + except ValueError: + # keep unknown decorators + node.decorator_list.append(decorator) + continue + return self.generic_visit(node) + except ValueError as e: + raise ValueError(f"node {node.name} failed") from e diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index 35219bee1..a7358b0c4 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -251,17 +251,11 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): keywords=[], ) ) - # convert class contents - wrapped_node = transformers_globals["RmAioFunctions"]().visit(wrapped_node) - replace_dict = self.replace_symbols or {} - replace_dict.update({"CrossSync": "CrossSync._Sync_Impl"}) - wrapped_node = transformers_globals["SymbolReplacer"](replace_dict).visit( - wrapped_node - ) - # visit CrossSync method decorators - wrapped_node = transformers_globals["CrossSyncMethodDecoratorHandler"]().visit( - wrapped_node - ) + # replace symbols if specified + if self.replace_symbols: + wrapped_node = transformers_globals["SymbolReplacer"](self.replace_symbols).visit( + wrapped_node + ) # update docstring if specified if self.sync_docstring_format_vars: docstring = ast.get_docstring(wrapped_node) @@ -394,9 +388,24 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): """ convert async to sync """ + import ast + # always convert method to sync + converted = ast.copy_location( + ast.FunctionDef( + wrapped_node.name, + wrapped_node.args, + wrapped_node.body, + wrapped_node.decorator_list + if hasattr(wrapped_node, "decorator_list") + else [], + wrapped_node.returns if hasattr(wrapped_node, "returns") else None, + ), + wrapped_node, + ) + # convert entire body to sync if rm_aio is set if self.rm_aio: - wrapped_node = transformers_globals["AsyncToSync"]().visit(wrapped_node) - return wrapped_node + converted = transformers_globals["AsyncToSync"]().visit(converted) + return converted class PytestFixture(AstDecorator): diff --git a/tests/system/cross_sync/test_cases/cross_sync_files.yaml b/tests/system/cross_sync/test_cases/cross_sync_files.yaml index a49b8189e..f978cfc17 100644 --- 
a/tests/system/cross_sync/test_cases/cross_sync_files.yaml +++ b/tests/system/cross_sync/test_cases/cross_sync_files.yaml @@ -6,7 +6,7 @@ tests: pass transformers: - - name: CrossSyncFileHandler + - name: CrossSyncFileProcessor after: null - description: "CrossSync.export_sync with default sync_name" @@ -18,7 +18,7 @@ tests: pass transformers: - - name: CrossSyncFileHandler + - name: CrossSyncFileProcessor after: | class MyClass: @@ -34,7 +34,7 @@ tests: pass transformers: - - name: CrossSyncFileHandler + - name: CrossSyncFileProcessor after: | class MyClass: @@ -53,7 +53,7 @@ tests: self.base = base transformers: - - name: CrossSyncFileHandler + - name: CrossSyncFileProcessor after: | class MyClass(ParentB): @@ -71,7 +71,7 @@ tests: """This is a {type} class.""" transformers: - - name: CrossSyncFileHandler + - name: CrossSyncFileProcessor after: | class MyClass: """This is a sync class.""" @@ -99,7 +99,7 @@ tests: pass transformers: - - name: CrossSyncFileHandler + - name: CrossSyncFileProcessor after: | @some_other_decorator class MyClass: @@ -115,17 +115,36 @@ tests: def fixture(self): pass - - description: "CrossSync.export_sync with nested classes" + - description: "CrossSync.export_sync with nested classes drop" before: | __CROSS_SYNC_OUTPUT__ = "out.path" - @CrossSync.export_sync(sync_name="MyClass", replace_symbols={"AsyncBase": "SyncBase"}) + @CrossSync.export_sync(sync_name="MyClass") class MyAsyncClass: class NestedAsyncClass: async def nested_method(self, base: AsyncBase): pass - @CrossSync.drop_method - async def drop_this_method(self): + @CrossSync.convert + async def use_nested(self): + nested = self.NestedAsyncClass() + CrossSync.rm_aio(await nested.nested_method()) + transformers: + - name: CrossSyncFileProcessor + after: | + class MyClass: + + def use_nested(self): + nested = self.NestedAsyncClass() + nested.nested_method() + + - description: "CrossSync.export_sync with nested classes" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.export_sync(sync_name="MyClass", replace_symbols={"AsyncBase": "SyncBase"}) + class MyAsyncClass: + @CrossSync.export_sync + class NestedClass: + async def nested_method(self, base: AsyncBase): pass @CrossSync.convert @@ -133,11 +152,11 @@ tests: nested = self.NestedAsyncClass() CrossSync.rm_aio(await nested.nested_method()) transformers: - - name: CrossSyncFileHandler + - name: CrossSyncFileProcessor after: | class MyClass: - class NestedAsyncClass: + class NestedClass: async def nested_method(self, base: SyncBase): pass @@ -158,12 +177,12 @@ tests: pass transformers: - - name: CrossSyncFileHandler + - name: CrossSyncFileProcessor after: | @CrossSync._Sync_Impl.add_mapping_decorator("MyClass") class MyClass: - async def my_method(self): + async def my_method(self): pass - description: "CrossSync.export_sync with CrossSync calls" @@ -177,10 +196,216 @@ tests: CrossSync.rm_aio(await CrossSync.yield_to_event_loop()) transformers: - - name: CrossSyncFileHandler + - name: CrossSyncFileProcessor after: | class MyClass: def my_method(self): with CrossSync._Sync_Impl.Condition() as c: CrossSync._Sync_Impl.yield_to_event_loop() + + - description: "Convert async method with @CrossSync.convert" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.convert + async def my_method(self, arg): + pass + transformers: [CrossSyncFileProcessor] + after: | + def my_method(self, arg): + pass + + - description: "Convert async method with custom sync name" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + 
@CrossSync.convert(sync_name="sync_method") + async def async_method(self, arg): + return await self.helper(arg) + transformers: [CrossSyncFileProcessor] + after: | + def sync_method(self, arg): + return await self.helper(arg) + + - description: "Convert async method with rm_aio=True" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.convert(rm_aio=True) + async def async_method(self): + async with self.lock: + async for item in self.items: + await self.process(item) + transformers: [CrossSyncFileProcessor] + after: | + def async_method(self): + with self.lock: + for item in self.items: + self.process(item) + + - description: "Drop method from sync version" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + def keep_method(self): + pass + + @CrossSync.drop_method + async def async_only_method(self): + await self.async_operation() + transformers: [CrossSyncFileProcessor] + after: | + def keep_method(self): + pass + + - description: "Convert.pytest" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.pytest + async def test_async_function(): + result = await async_operation() + assert result == expected_value + transformers: [CrossSyncFileProcessor] + after: | + def test_async_function(): + result = async_operation() + assert result == expected_value + + - description: "CrossSync.pytest with rm_aio=False" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.pytest(rm_aio=False) + async def test_partial_async(): + async with context_manager(): + result = await async_function() + assert result == expected_value + transformers: [CrossSyncFileProcessor] + after: | + def test_partial_async(): + async with context_manager(): + result = await async_function() + assert result == expected_value + + - description: "Convert async pytest fixture" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.pytest_fixture + @CrossSync.convert(rm_aio=True) + async def my_fixture(): + resource = await setup_resource() + yield resource + await cleanup_resource(resource) + transformers: [CrossSyncFileProcessor] + after: | + @pytest.fixture() + def my_fixture(): + resource = setup_resource() + yield resource + cleanup_resource(resource) + + - description: "Convert pytest fixture with custom parameters" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.pytest_fixture(scope="module", autouse=True) + def my_fixture(): + resource = setup_resource() + yield resource + cleanup_resource(resource) + transformers: [CrossSyncFileProcessor] + after: | + @pytest.fixture(scope="module", autouse=True) + def my_fixture(): + resource = setup_resource() + yield resource + cleanup_resource(resource) + + - description: "Convert method with multiple stacked decorators" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.convert(sync_name="sync_multi_decorated") + @CrossSync.pytest + @some_other_decorator + async def async_multi_decorated(self, arg): + result = await self.async_operation(arg) + return result + transformers: [CrossSyncFileProcessor] + after: | + @some_other_decorator + def sync_multi_decorated(self, arg): + result = self.async_operation(arg) + return result + + - description: "Convert method with multiple stacked decorators in class" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.export_sync + class MyClass: + @CrossSync.convert(sync_name="sync_multi_decorated") + @CrossSync.pytest + @some_other_decorator + async def async_multi_decorated(self, arg): + result = await self.async_operation(arg) + return result + transformers: 
[CrossSyncFileProcessor] + after: | + class MyClass: + + @some_other_decorator + def sync_multi_decorated(self, arg): + result = self.async_operation(arg) + return result + + - description: "Convert method with stacked decorators including rm_aio" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.convert(rm_aio=True) + @CrossSync.pytest_fixture(scope="function") + @another_decorator + async def async_fixture_with_context(): + async with some_async_context(): + resource = await setup_async_resource() + yield resource + await cleanup_async_resource(resource) + transformers: [CrossSyncFileProcessor] + after: | + @pytest.fixture(scope="function") + @another_decorator + def async_fixture_with_context(): + with some_async_context(): + resource = setup_async_resource() + yield resource + cleanup_async_resource(resource) + + - description: "Handle CrossSync.is_async conditional" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + if CrossSync.is_async: + import a + else: + import b + + def my_method(self): + if CrossSync.is_async: + return "async version" + else: + return "sync version" + transformers: [CrossSyncFileProcessor] + after: | + import b + + def my_method(self): + return "sync version" + + - description: "Replace CrossSync symbols" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + CrossSync.sleep(1) + @CrossSync.export_sync + class MyClass: + event = CrossSync.Event() + def my_method(self): + return CrossSync.some_function() + transformers: [CrossSyncFileProcessor] + after: | + CrossSync._Sync_Impl.sleep(1) + class MyClass: + event = CrossSync._Sync_Impl.Event() + def my_method(self): + return CrossSync._Sync_Impl.some_function() diff --git a/tests/system/cross_sync/test_cases/cross_sync_methods.yaml b/tests/system/cross_sync/test_cases/cross_sync_methods.yaml deleted file mode 100644 index ca2222a52..000000000 --- a/tests/system/cross_sync/test_cases/cross_sync_methods.yaml +++ /dev/null @@ -1,144 +0,0 @@ -tests: - - description: "Convert async method with @CrossSync.convert" - before: | - @CrossSync.convert - async def my_method(self, arg): - pass - transformers: [CrossSyncMethodDecoratorHandler] - after: | - def my_method(self, arg): - pass - - - description: "Convert async method with custom sync name" - before: | - @CrossSync.convert(sync_name="sync_method") - async def async_method(self, arg): - return await self.helper(arg) - transformers: [CrossSyncMethodDecoratorHandler] - after: | - def sync_method(self, arg): - return await self.helper(arg) - - - description: "Convert async method with symbol replacement" - before: | - @CrossSync.convert(replace_symbols={"old": "new"}) - async def my_method(self): - old = 1 - transformers: [CrossSyncMethodDecoratorHandler] - after: | - def my_method(self): - new = 1 - - - description: "Convert async method with rm_aio=True" - before: | - @CrossSync.convert(rm_aio=True) - async def async_method(self): - async with self.lock: - async for item in self.items: - await self.process(item) - transformers: [CrossSyncMethodDecoratorHandler] - after: | - def async_method(self): - with self.lock: - for item in self.items: - self.process(item) - - - description: "Convert async method with docstring formatting" - before: | - @CrossSync.convert(docstring_format_vars={"mode": ("async", "sync")}) - async def async_method(self): - """This is a {mode} method.""" - transformers: [CrossSyncMethodDecoratorHandler] - after: | - def async_method(self): - """This is a sync method.""" - - - description: "Drop method from sync version" - before: | - def 
keep_method(self): - pass - - @CrossSync.drop_method - async def async_only_method(self): - await self.async_operation() - transformers: [CrossSyncMethodDecoratorHandler] - after: | - def keep_method(self): - pass - - - description: "Convert.pytest" - before: | - @CrossSync.pytest - async def test_async_function(): - result = await async_operation() - assert result == expected_value - transformers: [CrossSyncMethodDecoratorHandler] - after: | - def test_async_function(): - result = async_operation() - assert result == expected_value - - - description: "CrossSync.pytest with rm_aio=False" - before: | - @CrossSync.pytest(rm_aio=False) - async def test_partial_async(): - async with context_manager(): - result = await async_function() - assert result == expected_value - transformers: [CrossSyncMethodDecoratorHandler] - after: | - async def test_partial_async(): - async with context_manager(): - result = await async_function() - assert result == expected_value - - - description: "Convert pytest fixture with custom parameters" - before: | - @CrossSync.pytest_fixture(scope="module", autouse=True) - async def async_fixture(): - resource = await setup_resource() - yield resource - await cleanup_resource(resource) - transformers: [CrossSyncMethodDecoratorHandler] - after: | - @pytest.fixture(scope="module", autouse=True) - async def async_fixture(): - resource = await setup_resource() - yield resource - await cleanup_resource(resource) - - - description: "Convert method with multiple stacked decorators" - before: | - @CrossSync.convert(sync_name="sync_multi_decorated") - @CrossSync.pytest - @some_other_decorator - async def async_multi_decorated(self, arg): - result = await self.async_operation(arg) - return result - transformers: [CrossSyncMethodDecoratorHandler] - after: | - @some_other_decorator - def sync_multi_decorated(self, arg): - result = self.async_operation(arg) - return result - - - description: "Convert method with stacked decorators including rm_aio" - before: | - @CrossSync.convert(rm_aio=True) - @CrossSync.pytest_fixture(scope="function") - @another_decorator - async def async_fixture_with_context(): - async with some_async_context(): - resource = await setup_async_resource() - yield resource - await cleanup_async_resource(resource) - transformers: [CrossSyncMethodDecoratorHandler] - after: | - @pytest.fixture(scope="function") - @another_decorator - def async_fixture_with_context(): - with some_async_context(): - resource = setup_async_resource() - yield resource - cleanup_async_resource(resource) - diff --git a/tests/system/cross_sync/test_cross_sync_e2e.py b/tests/system/cross_sync/test_cross_sync_e2e.py index bd08ed6cb..ab0e70162 100644 --- a/tests/system/cross_sync/test_cross_sync_e2e.py +++ b/tests/system/cross_sync/test_cross_sync_e2e.py @@ -14,8 +14,7 @@ SymbolReplacer, AsyncToSync, RmAioFunctions, - CrossSyncMethodDecoratorHandler, - CrossSyncFileHandler, + CrossSyncFileProcessor, ) From e42655bce9c49ddc07112b3a01fc40c9d57a1bdf Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 12:17:52 -0700 Subject: [PATCH 278/360] simplified Convert decorator --- .../data/_sync/cross_sync/_decorators.py | 93 +++++++++---------- .../test_cases/cross_sync_files.yaml | 16 ++++ .../data/_sync/test_cross_sync_decorators.py | 78 ++++++++-------- 3 files changed, 99 insertions(+), 88 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index a7358b0c4..d5f1ac107 100644 --- 
a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -140,13 +140,25 @@ def get_for_node(cls, node: ast.Call | ast.Attribute | ast.Name) -> "AstDecorato ) # convert to standardized representation formatted_name = decorator_name.replace("_", "").lower() - for subclass in cls.__subclasses__(): + for subclass in cls.get_subclasses(): if subclass.__name__.lower() == formatted_name: return subclass(*got_args, **got_kwargs) raise ValueError(f"Unknown decorator encountered: {decorator_name}") else: raise ValueError("Not a CrossSync decorator") + @classmethod + def get_subclasses(cls) -> list[type[AstDecorator]]: + """ + Get all subclasses of AstDecorator + + Returns: + list of all subclasses of AstDecorator + """ + for subclass in cls.__subclasses__(): + yield from subclass.get_subclasses() + yield subclass + @classmethod def _convert_ast_to_py(cls, ast_node: ast.expr | None) -> Any: """ @@ -178,7 +190,10 @@ class ExportSync(AstDecorator): sync_name: use a new name for the sync class replace_symbols: a dict of symbols and replacements to use when generating sync class docstring_format_vars: a dict of variables to replace in the docstring - add_mapping_for_name: when given, will add a new attribute to CrossSync, so the original class and its sync version can be accessed from CrossSync. + rm_aio: if True, automatically strip all asyncio keywords from method. If false, + only keywords wrapped in CrossSync.rm_aio() calls to be removed. + add_mapping_for_name: when given, will add a new attribute to CrossSync, + so the original class and its sync version can be accessed from CrossSync. """ def __init__( @@ -187,6 +202,7 @@ def __init__( *, replace_symbols: dict[str, str] | None = None, docstring_format_vars: dict[str, tuple[str, str]] | None = None, + rm_aio: bool = False, add_mapping_for_name: str | None = None, ): self.sync_name = sync_name @@ -198,6 +214,7 @@ def __init__( self.sync_docstring_format_vars = { k: v[1] for k, v in docstring_format_vars.items() } + self.rm_aio = rm_aio self.add_mapping_for_name = add_mapping_for_name def async_decorator(self): @@ -206,6 +223,10 @@ def async_decorator(self): """ from .cross_sync import CrossSync + if not self.add_mapping_for_name and not self.async_docstring_format_vars: + # return None if no changes needed + return None + new_mapping = self.add_mapping_for_name def decorator(cls): @@ -236,6 +257,9 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): ] else: wrapped_node.decorator_list = [] + # strip async keywords if specified + if self.rm_aio: + wrapped_node = transformers_globals["AsyncToSync"]().visit(wrapped_node) # add mapping decorator if needed if self.add_mapping_for_name: wrapped_node.decorator_list.append( @@ -253,9 +277,9 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): ) # replace symbols if specified if self.replace_symbols: - wrapped_node = transformers_globals["SymbolReplacer"](self.replace_symbols).visit( - wrapped_node - ) + wrapped_node = transformers_globals["SymbolReplacer"]( + self.replace_symbols + ).visit(wrapped_node) # update docstring if specified if self.sync_docstring_format_vars: docstring = ast.get_docstring(wrapped_node) @@ -266,7 +290,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): return wrapped_node -class Convert(AstDecorator): +class Convert(ExportSync): """ Method decorator to mark async methods to be converted to sync methods @@ -281,22 +305,19 @@ class Convert(AstDecorator): def 
__init__( self, - *, sync_name: str | None = None, + *, replace_symbols: dict[str, str] | None = None, docstring_format_vars: dict[str, tuple[str, str]] | None = None, rm_aio: bool = False, ): - self.sync_name = sync_name - self.replace_symbols = replace_symbols - docstring_format_vars = docstring_format_vars or {} - self.async_docstring_format_vars = { - k: v[0] for k, v in docstring_format_vars.items() - } - self.sync_docstring_format_vars = { - k: v[1] for k, v in docstring_format_vars.items() - } - self.rm_aio = rm_aio + super().__init__( + sync_name=sync_name, + replace_symbols=replace_symbols, + docstring_format_vars=docstring_format_vars, + rm_aio=rm_aio, + add_mapping_for_name=None, + ) def sync_ast_transform(self, wrapped_node, transformers_globals): """ @@ -305,7 +326,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): import ast # replace async function with sync function - wrapped_node = ast.copy_location( + converted = ast.copy_location( ast.FunctionDef( wrapped_node.name, wrapped_node.args, @@ -317,39 +338,8 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): ), wrapped_node, ) - # update name if specified - if self.sync_name: - wrapped_node.name = self.sync_name - # strip async keywords if specified - if self.rm_aio: - wrapped_node = transformers_globals["AsyncToSync"]().visit(wrapped_node) - # update arbitrary symbols if specified - if self.replace_symbols: - replacer = transformers_globals["SymbolReplacer"] - wrapped_node = replacer(self.replace_symbols).visit(wrapped_node) - # update docstring if specified - if self.sync_docstring_format_vars: - docstring = ast.get_docstring(wrapped_node) - if docstring: - wrapped_node.body[0].value = ast.Constant( - value=docstring.format(**self.sync_docstring_format_vars) - ) - return wrapped_node - - def async_decorator(self): - """ - If docstring_format_vars are provided, update the docstring of the async method - """ - - if self.async_docstring_format_vars: - - def decorator(f): - f.__doc__ = f.__doc__.format(**self.async_docstring_format_vars) - return f - - return decorator - else: - return None + # transform based on arguments + return super().sync_ast_transform(converted, transformers_globals) class DropMethod(AstDecorator): @@ -389,6 +379,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): convert async to sync """ import ast + # always convert method to sync converted = ast.copy_location( ast.FunctionDef( diff --git a/tests/system/cross_sync/test_cases/cross_sync_files.yaml b/tests/system/cross_sync/test_cases/cross_sync_files.yaml index f978cfc17..9c8627657 100644 --- a/tests/system/cross_sync/test_cases/cross_sync_files.yaml +++ b/tests/system/cross_sync/test_cases/cross_sync_files.yaml @@ -185,6 +185,22 @@ tests: async def my_method(self): pass + - description: "CrossSync.export_sync with rm_aio" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.export_sync(rm_aio=True) + class MyClass: + async def my_method(self): + async for item in self.items: + await self.process(item) + transformers: [CrossSyncFileProcessor] + after: | + class MyClass: + + def my_method(self): + for item in self.items: + self.process(item) + - description: "CrossSync.export_sync with CrossSync calls" before: | __CROSS_SYNC_OUTPUT__ = "out.path" diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_sync/test_cross_sync_decorators.py index 6c817fd9d..727d5a599 100644 --- a/tests/unit/data/_sync/test_cross_sync_decorators.py +++ 
b/tests/unit/data/_sync/test_cross_sync_decorators.py @@ -51,24 +51,28 @@ def test_ctor_defaults(self): assert instance.add_mapping_for_name is None assert instance.async_docstring_format_vars == {} assert instance.sync_docstring_format_vars == {} + assert instance.rm_aio is False def test_ctor(self): sync_name = "sync_name" replace_symbols = {"a": "b"} docstring_format_vars = {"A": (1, 2)} add_mapping_for_name = "test_name" + rm_aio = True instance = self._get_class()( sync_name, replace_symbols=replace_symbols, docstring_format_vars=docstring_format_vars, add_mapping_for_name=add_mapping_for_name, + rm_aio=rm_aio, ) assert instance.sync_name is sync_name assert instance.replace_symbols is replace_symbols assert instance.add_mapping_for_name is add_mapping_for_name assert instance.async_docstring_format_vars == {"A": 1} assert instance.sync_docstring_format_vars == {"A": 2} + assert instance.rm_aio is rm_aio def test_class_decorator(self): """ @@ -215,26 +219,6 @@ def test_sync_ast_transform_add_docstring_format( assert isinstance(result.body[0].value, ast.Constant) assert result.body[0].value.value == expected - def test_sync_ast_transform_call_cross_sync_transforms(self): - """ - Should use transformers_globals to call some extra transforms on class: - - RmAioFunctions - - SymbolReplacer - - CrossSyncMethodDecoratorHandler - """ - decorator = self._get_class()("path.to.SyncClass") - mock_node = ast.ClassDef(name="AsyncClass", bases=[], keywords=[], body=[]) - - transformers_globals = { - "RmAioFunctions": mock.Mock(), - "SymbolReplacer": mock.Mock(), - "CrossSyncMethodDecoratorHandler": mock.Mock(), - } - decorator.sync_ast_transform(mock_node, transformers_globals) - # ensure each transformer was called - for transformer in transformers_globals.values(): - assert transformer.call_count == 1 - def test_sync_ast_transform_replace_symbols(self, globals_mock): """ SymbolReplacer should be called with replace_symbols @@ -253,8 +237,19 @@ def test_sync_ast_transform_replace_symbols(self, globals_mock): assert "a" in found_dict for k, v in replace_symbols.items(): assert found_dict[k] == v - # should also add CrossSync replacement - assert found_dict["CrossSync"] == "CrossSync._Sync_Impl" + + def test_sync_ast_transform_rmaio_calls_async_to_sync(self): + """ + Should call AsyncToSync if rm_aio is set + """ + decorator = self._get_class()(rm_aio=True) + mock_node = ast.ClassDef(name="AsyncClass", bases=[], keywords=[], body=[]) + async_to_sync_mock = mock.Mock() + async_to_sync_mock.visit.side_effect = lambda x: x + globals_mock = {"AsyncToSync": lambda: async_to_sync_mock} + + decorator.sync_ast_transform(mock_node, globals_mock) + assert async_to_sync_mock.visit.call_count == 1 class TestConvertDecorator: @@ -352,7 +347,7 @@ def test_sync_ast_transform_replaces_name(self, globals_mock): assert isinstance(result, ast.FunctionDef) assert result.name == "new_method_name" - def test_sync_ast_transform_calls_async_to_sync(self): + def test_sync_ast_transform_rmaio_calls_async_to_sync(self): """ Should call AsyncToSync if rm_aio is set """ @@ -459,34 +454,43 @@ def test_decorator_functionality(self): """ Should wrap the class with pytest.mark.asyncio """ - unwrapped_class = mock.Mock - wrapped_class = self._get_class().decorator(unwrapped_class) - assert wrapped_class == pytest.mark.asyncio(unwrapped_class) + unwrapped_fn = mock.Mock + wrapped_class = self._get_class().decorator(unwrapped_fn) + assert wrapped_class == pytest.mark.asyncio(unwrapped_fn) def test_sync_ast_transform(self): """ - 
Should be no-op if rm_aio is not set + If rm_aio is True (default), should call AsyncToSync on the class """ - decorator = self._get_class()(rm_aio=False) + decorator = self._get_class()() + mock_node = ast.AsyncFunctionDef( + name="AsyncMethod", args=ast.arguments(), body=[] + ) - input_obj = object() - result = decorator.sync_ast_transform(input_obj, {}) + async_to_sync_mock = mock.Mock() + async_to_sync_mock.visit.side_effect = lambda x: x + globals_mock = {"AsyncToSync": lambda: async_to_sync_mock} - assert result is input_obj + transformed = decorator.sync_ast_transform(mock_node, globals_mock) + assert async_to_sync_mock.visit.call_count == 1 + assert isinstance(transformed, ast.FunctionDef) - def test_sync_ast_transform_rm_aio(self): + def test_sync_ast_transform_no_rm_aio(self): """ - If rm_aio is set, should call AsyncToSync on the class + if rm_aio is False, should remove the async keyword from the method """ - decorator = self._get_class()() - mock_node = ast.ClassDef(name="AsyncClass", bases=[], keywords=[], body=[]) + decorator = self._get_class()(rm_aio=False) + mock_node = ast.AsyncFunctionDef( + name="AsyncMethod", args=ast.arguments(), body=[] + ) async_to_sync_mock = mock.Mock() async_to_sync_mock.visit.return_value = mock_node globals_mock = {"AsyncToSync": lambda: async_to_sync_mock} - decorator.sync_ast_transform(mock_node, globals_mock) - assert async_to_sync_mock.visit.call_count == 1 + transformed = decorator.sync_ast_transform(mock_node, globals_mock) + assert async_to_sync_mock.visit.call_count == 0 + assert isinstance(transformed, ast.FunctionDef) class TestPytestFixtureDecorator: From e9723859c1343bbd413162f1ac7f38adcb4c8b15 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 12:20:23 -0700 Subject: [PATCH 279/360] fixed 3.7 tests --- tests/unit/data/_async/test_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 88bc6ca7b..6c49ca0da 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -33,11 +33,11 @@ from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule # try/except added for compatibility with python < 3.8 -from unittest import mock - try: + from unittest import mock from unittest.mock import AsyncMock # type: ignore except ImportError: # pragma: NO COVER + import mock # type: ignore from mock import AsyncMock # type: ignore VENEER_HEADER_REGEX = re.compile( From 81a06f8d54114fd2ae73162b16967c78a33e9453 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 12:24:30 -0700 Subject: [PATCH 280/360] fixed mypy issues --- google/cloud/bigtable/data/_sync/cross_sync/_decorators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index d5f1ac107..965765161 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -16,7 +16,7 @@ Each AstDecorator class is used through @CrossSync. 
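# A minimal standalone sketch (not part of this patch) of the AsyncFunctionDef -> FunctionDef
# rewrite that the Convert and Pytest decorators perform above. The field copy below is
# equivalent to the explicit name/args/body/decorator_list/returns copy in the decorator code;
# a separate AsyncToSync pass (the rm_aio option) is what later strips the remaining
# await / async-for keywords from the body.
import ast

source = "async def fetch(self):\n    return await self.call()"
module = ast.parse(source)
async_fn = module.body[0]
sync_fn = ast.copy_location(
    ast.FunctionDef(**{field: getattr(async_fn, field) for field in async_fn._fields}),
    async_fn,
)
module.body[0] = ast.fix_missing_locations(sync_fn)
# prints "def fetch(self): ..." -- the await is still present until rm_aio runs
print(ast.unparse(module))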
""" from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable if TYPE_CHECKING: import ast @@ -148,7 +148,7 @@ def get_for_node(cls, node: ast.Call | ast.Attribute | ast.Name) -> "AstDecorato raise ValueError("Not a CrossSync decorator") @classmethod - def get_subclasses(cls) -> list[type[AstDecorator]]: + def get_subclasses(cls) -> Iterable[type["AstDecorator"]]: """ Get all subclasses of AstDecorator From db8c1561568751eb2aac4905e62d35ec366c62f6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 15:18:12 -0700 Subject: [PATCH 281/360] extract header from code --- .cross_sync/generate.py | 46 ++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index 321ad82fa..618dfbc6c 100644 --- a/.cross_sync/generate.py +++ b/.cross_sync/generate.py @@ -23,11 +23,27 @@ """ +def extract_header_comments(file_path) -> str: + """ + Extract the file header. Header is defined as the top-level + comments before any code or imports + """ + header = [] + with open(file_path, "r") as f: + for line in f: + if line.startswith("#") or line.strip() == "": + header.append(line) + else: + break + return "".join(header) + + class CrossSyncOutputFile: - def __init__(self, file_path: str, ast_tree): - self.file_path = file_path + def __init__(self, output_path: str, ast_tree, header: str | None = None): + self.output_path = output_path self.tree = ast_tree + self.header = header or "" def render(self, with_black=True, save_to_disk: bool = False) -> str: """ @@ -37,24 +53,7 @@ def render(self, with_black=True, save_to_disk: bool = False) -> str: with_black: whether to run the output through black before returning save_to_disk: whether to write the output to the file path """ - header = ( - "# Copyright 2024 Google LLC\n" - "#\n" - '# Licensed under the Apache License, Version 2.0 (the "License");\n' - "# you may not use this file except in compliance with the License.\n" - "# You may obtain a copy of the License at\n" - "#\n" - "# http://www.apache.org/licenses/LICENSE-2.0\n" - "#\n" - "# Unless required by applicable law or agreed to in writing, software\n" - '# distributed under the License is distributed on an "AS IS" BASIS,\n' - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" - "# See the License for the specific language governing permissions and\n" - "# limitations under the License.\n" - "#\n" - "# This file is automatically generated by CrossSync. 
Do not edit manually.\n" - ) - full_str = header + ast.unparse(self.tree) + full_str = self.header + ast.unparse(self.tree) if with_black: import black # type: ignore import autoflake # type: ignore @@ -65,8 +64,8 @@ def render(self, with_black=True, save_to_disk: bool = False) -> str: ) if save_to_disk: import os - os.makedirs(os.path.dirname(self.file_path), exist_ok=True) - with open(self.file_path, "w") as f: + os.makedirs(os.path.dirname(self.output_path), exist_ok=True) + with open(self.output_path, "w") as f: f.write(full_str) return full_str @@ -87,7 +86,8 @@ def convert_files_in_dir(directory: str) -> set[CrossSyncOutputFile]: if output_path is not None: # contains __CROSS_SYNC_OUTPUT__ annotation converted_tree = file_transformer.visit(ast_tree) - artifacts.add(CrossSyncOutputFile(output_path, converted_tree)) + header = extract_header_comments(file_path) + artifacts.add(CrossSyncOutputFile(output_path, converted_tree, header)) # return set of output artifacts return artifacts From d0dcbff5aaccaa2845718631f9b685bf635041df Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 15:52:52 -0700 Subject: [PATCH 282/360] fix mock import --- tests/unit/data/_async/test_client.py | 7 +------ tests/unit/data/_async/test_mutations_batcher.py | 8 +------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 497326175..bffd4240c 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -19,6 +19,7 @@ import sys import pytest +import mock from google.cloud.bigtable.data import mutations from google.auth.credentials import AnonymousCredentials @@ -35,12 +36,6 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -# try/except added for compatibility with python < 3.8 -try: - from unittest import mock -except ImportError: # pragma: NO COVER - import mock # type: ignore - if CrossSync.is_async: from google.api_core import grpc_helpers_async from google.cloud.bigtable.data._async.client import TableAsync diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 0361da5eb..cdfdc5ab7 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -13,6 +13,7 @@ # limitations under the License. 
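# Standalone sketch of the post-processing pipeline used by render() in generate.py above
# (assumes black and autoflake are installed, as they already are for this repo): unparse the
# transformed tree, drop unused imports with autoflake, then format the result with black.
import ast
import autoflake
from black import format_str, FileMode

tree = ast.parse("import os\nimport sys\nx = sys.maxsize")
raw = ast.unparse(tree)
cleaned = autoflake.fix_code(raw, remove_all_unused_imports=True)
# the unused "import os" is gone and the remaining code is black-formatted
print(format_str(cleaned, mode=FileMode()))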
import pytest +import mock import asyncio import time import google.api_core.exceptions as core_exceptions @@ -22,13 +23,6 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -# try/except added for compatibility with python < 3.8 -try: - from unittest import mock -except ImportError: # pragma: NO COVER - import mock # type: ignore - - __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_mutations_batcher" From 22d3e40e42f3b3fc474c934cd14f94487f4090a0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 16:58:40 -0700 Subject: [PATCH 283/360] fix test_sync_up_to_date test --- tests/unit/data/test_sync_up_to_date.py | 45 +++++++++++++++++++++---- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/tests/unit/data/test_sync_up_to_date.py b/tests/unit/data/test_sync_up_to_date.py index d72c70988..d170638c4 100644 --- a/tests/unit/data/test_sync_up_to_date.py +++ b/tests/unit/data/test_sync_up_to_date.py @@ -15,6 +15,8 @@ import sys import hashlib import pytest +import ast +import re from difflib import unified_diff # add cross_sync to path @@ -23,24 +25,30 @@ cross_sync_path = os.path.join(repo_root, ".cross_sync") sys.path.append(cross_sync_path) -from generate import convert_files_in_dir # noqa: E402 +from generate import convert_files_in_dir, CrossSyncOutputFile # noqa: E402 + +sync_files = list(convert_files_in_dir(repo_root)) @pytest.mark.skipif( sys.version_info < (3, 9), reason="ast.unparse is only available in 3.9+" ) @pytest.mark.parametrize( - "artifact", convert_files_in_dir(repo_root), ids=lambda a: a.file_path + "sync_file", sync_files, ids=lambda f: f.output_path ) -def test_sync_up_to_date(artifact): +def test_sync_up_to_date(sync_file): """ Generate a fresh copy of each cross_sync file, and compare hashes with the existing file. If this test fails, run `nox -s generate_sync` to update the sync files. """ - path = artifact.file_path - new_render = artifact.render() - found_render = open(path).read() + path = sync_file.output_path + new_render = sync_file.render(with_black=True, save_to_disk=False) + found_render = CrossSyncOutputFile( + output_path="", + ast_tree=ast.parse(open(path).read()), + header=sync_file.header + ).render(with_black=True, save_to_disk=False) # compare by content diff = unified_diff(found_render.splitlines(), new_render.splitlines(), lineterm="") diff_str = "\n".join(diff) @@ -49,3 +57,28 @@ def test_sync_up_to_date(artifact): new_hash = hashlib.md5(new_render.encode()).hexdigest() found_hash = hashlib.md5(found_render.encode()).hexdigest() assert new_hash == found_hash, f"md5 mismatch for {path}" + +@pytest.mark.parametrize( + "sync_file", sync_files, ids=lambda f: f.output_path +) +def test_verify_headers(sync_file): + license_regex = r""" + \#\ Copyright\ \d{4}\ Google\ LLC\n + \#\n + \#\ Licensed\ under\ the\ Apache\ License,\ Version\ 2\.0\ \(the\ \"License\"\);\n + \#\ you\ may\ not\ use\ this\ file\ except\ in\ compliance\ with\ the\ License\.\n + \#\ You\ may\ obtain\ a\ copy\ of\ the\ License\ at\ + \#\n + \#\s+http:\/\/www\.apache\.org\/licenses\/LICENSE-2\.0\n + \#\n + \#\ Unless\ required\ by\ applicable\ law\ or\ agreed\ to\ in\ writing,\ software\n + \#\ distributed\ under\ the\ License\ is\ distributed\ on\ an\ \"AS\ IS\"\ BASIS,\n + \#\ WITHOUT\ WARRANTIES\ OR\ CONDITIONS\ OF\ ANY\ KIND,\ either\ express\ or\ implied\.\n + \#\ See\ the\ License\ for\ the\ specific\ language\ governing\ permissions\ and\n + \#\ limitations\ under\ the\ License\. 
+ """ + pattern = re.compile(license_regex, re.VERBOSE) + + with open(sync_file.output_path, "r") as f: + content = f.read() + assert pattern.search(content), "Missing license header" From c16ea3015fd7875bdc5837c60a6fbf57fd7447fe Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 16:59:32 -0700 Subject: [PATCH 284/360] fix annotations --- google/cloud/bigtable/data/_async/client.py | 1 + google/cloud/bigtable/data/_async/mutations_batcher.py | 4 ++-- tests/unit/data/_async/test__read_rows.py | 1 + tests/unit/data/_async/test_client.py | 2 ++ tests/unit/data/_async/test_read_rows_acceptance.py | 1 + 5 files changed, 7 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 85f24ef7e..5708f9db1 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -93,6 +93,7 @@ ) else: + from typing import Iterable # noqa: F401 from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport as PooledTransportType # type: ignore from google.cloud.bigtable.data._sync.mutations_batcher import ( # noqa: F401 MutationsBatcher, diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 9558b85a2..2aa8e47ac 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -14,7 +14,7 @@ # from __future__ import annotations -from typing import Sequence, TYPE_CHECKING +from typing import Sequence, TYPE_CHECKING, cast import atexit import warnings from collections import deque @@ -309,7 +309,7 @@ async def append(self, mutation_entry: RowMutationEntry): # TODO: return a future to track completion of this entry if self._closed.is_set(): raise RuntimeError("Cannot append to closed MutationsBatcher") - if isinstance(mutation_entry, Mutation): # type: ignore + if isinstance(cast(Mutation, mutation_entry), Mutation): raise ValueError( f"invalid mutation type: {type(mutation_entry).__name__}. Only RowMutationEntry objects are supported by batcher" ) diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index c41914e84..360d82380 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -1,3 +1,4 @@ +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index bffd4240c..3b5c0fe8e 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1496,6 +1496,7 @@ async def _make_gapic_stream( ): from google.cloud.bigtable_v2 import ReadRowsResponse + @CrossSync.export_sync class mock_stream: def __init__(self, chunk_list, sleep_time): self.chunk_list = chunk_list @@ -3166,6 +3167,7 @@ def _make_client(self, *args, **kwargs): @CrossSync.convert def _make_gapic_stream(self, sample_list: list["ExecuteQueryResponse" | Exception]): + @CrossSync.export_sync class MockStream: def __init__(self, sample_list): self.sample_list = sample_list diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index bf5af9786..2b52b3b01 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -137,6 +137,7 @@ async def test_read_rows_scenario(self, test_case: ReadRowsTest): async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): from google.cloud.bigtable_v2 import ReadRowsResponse + @CrossSync.export_sync class mock_stream: def __init__(self, chunk_list): self.chunk_list = chunk_list From 37c322603d2acf9e7d54d439f846dcb10ea3b7be Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 17:00:20 -0700 Subject: [PATCH 285/360] add warning to generated header --- .cross_sync/generate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index 618dfbc6c..86d515765 100644 --- a/.cross_sync/generate.py +++ b/.cross_sync/generate.py @@ -35,6 +35,7 @@ def extract_header_comments(file_path) -> str: header.append(line) else: break + header.append("\n# This file is automatically generated by CrossSync. Do not edit manually.\n\n") return "".join(header) @@ -102,5 +103,5 @@ def save_artifacts(artifacts: Sequence[CrossSyncOutputFile]): search_root = sys.argv[1] outputs = convert_files_in_dir(search_root) - print(f"Generated {len(outputs)} artifacts: {[a.file_path for a in outputs]}") + print(f"Generated {len(outputs)} artifacts: {[a.output_path for a in outputs]}") save_artifacts(outputs) From 54ee3bec7a3e15f0720b1090159221c478b97475 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 17:04:32 -0700 Subject: [PATCH 286/360] updated README --- .cross_sync/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.cross_sync/README.md b/.cross_sync/README.md index 0d43c1bb4..e02497a80 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -47,10 +47,10 @@ Additionally, CrossSync provides method implementations that work equivalently i CrossSync provides a set of annotations to mark up async classes, to guide the generation of sync code. - `@CrossSync.export_sync` - - marks classes for conversion, along with an output file path + - marks classes for conversion. Unmarked classes will be droppd - if add_mapping is included, the async and sync classes can be accessed using a shared CrossSync.X alias - `@CrossSync.convert` - - marks async functions for conversion + - marks async functions for conversion. 
Unmarked methods will be copied as-is - `@CrossSync.drop_method` - marks functions that should not be included in sync output - `@CrossSync.pytest` From d44b829735a379ddbb182e747b213c91c7736122 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 17:28:40 -0700 Subject: [PATCH 287/360] made drop_method into drop, with support for classes --- .cross_sync/README.md | 4 ++-- .cross_sync/transformers.py | 12 +++++------- .../data/_sync/cross_sync/_decorators.py | 6 +++--- .../data/_sync/cross_sync/cross_sync.py | 4 ++-- .../test_cases/cross_sync_files.yaml | 18 ++++++++++++++++-- .../data/_sync/test_cross_sync_decorators.py | 6 +++--- 6 files changed, 31 insertions(+), 19 deletions(-) diff --git a/.cross_sync/README.md b/.cross_sync/README.md index e02497a80..a2f193985 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -51,8 +51,8 @@ CrossSync provides a set of annotations to mark up async classes, to guide the g - if add_mapping is included, the async and sync classes can be accessed using a shared CrossSync.X alias - `@CrossSync.convert` - marks async functions for conversion. Unmarked methods will be copied as-is -- `@CrossSync.drop_method` - - marks functions that should not be included in sync output +- `@CrossSync.drop` + - marks functions or classes that should not be included in sync output - `@CrossSync.pytest` - marks test functions. Test functions automatically have all async keywords stripped (i.e., rm_aio is unneeded) - `CrossSync.add_mapping` diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index b4ffe7ba6..a780f2bb9 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -240,18 +240,16 @@ def visit_ClassDef(self, node): Called for each class in file. If class has a CrossSync decorator, it will be transformed according to the decorator arguments. Otherwise, class is returned unchanged """ - for decorator in node.decorator_list: + orig_decorators = node.decorator_list + for decorator in orig_decorators: try: handler = AstDecorator.get_for_node(decorator) - if isinstance(handler, ExportSync): - # transformation is handled in sync_ast_transform method of the decorator - after_export = handler.sync_ast_transform(node, globals()) - return self.generic_visit(after_export) + # transformation is handled in sync_ast_transform method of the decorator + node = handler.sync_ast_transform(node, globals()) except ValueError: # not cross_sync decorator continue - # cross_sync decorator not found. 
Drop from sync version - return None + return self.generic_visit(node) if node else None def visit_If(self, node): """ diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index 965765161..8f90efce0 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -342,14 +342,14 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): return super().sync_ast_transform(converted, transformers_globals) -class DropMethod(AstDecorator): +class Drop(AstDecorator): """ - Method decorator to drop async methods from the sync output + Method decorator to drop methods or classes from the sync output """ def sync_ast_transform(self, wrapped_node, transformers_globals): """ - Drop method from sync output + Drop from sync output """ return None diff --git a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py index dceff5a62..2433562cf 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py @@ -62,7 +62,7 @@ async def async_func(self, arg: int) -> int: from ._decorators import ( ExportSync, Convert, - DropMethod, + Drop, Pytest, PytestFixture, ) @@ -99,7 +99,7 @@ class CrossSync(metaclass=MappingMeta): # decorators export_sync = ExportSync.decorator # decorate classes to convert convert = Convert.decorator # decorate methods to convert from async to sync - drop_method = DropMethod.decorator # decorate methods to remove from sync version + drop = Drop.decorator # decorate methods to remove from sync version pytest = Pytest.decorator # decorate test methods to run with pytest-asyncio pytest_fixture = ( PytestFixture.decorator diff --git a/tests/system/cross_sync/test_cases/cross_sync_files.yaml b/tests/system/cross_sync/test_cases/cross_sync_files.yaml index 9c8627657..3eca767ec 100644 --- a/tests/system/cross_sync/test_cases/cross_sync_files.yaml +++ b/tests/system/cross_sync/test_cases/cross_sync_files.yaml @@ -87,7 +87,7 @@ tests: async with self.base.connection(): return await self.base.my_method() - @CrossSync.drop_method + @CrossSync.drop async def async_only_method(self): await self.async_operation() @@ -120,6 +120,7 @@ tests: __CROSS_SYNC_OUTPUT__ = "out.path" @CrossSync.export_sync(sync_name="MyClass") class MyAsyncClass: + @CrossSync.drop class NestedAsyncClass: async def nested_method(self, base: AsyncBase): pass @@ -263,7 +264,7 @@ tests: def keep_method(self): pass - @CrossSync.drop_method + @CrossSync.drop async def async_only_method(self): await self.async_operation() transformers: [CrossSyncFileProcessor] @@ -271,6 +272,19 @@ tests: def keep_method(self): pass + - description: "Drop class from sync version" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.drop + class DropMe: + pass + class Keeper: + pass + transformers: [CrossSyncFileProcessor] + after: | + class Keeper: + pass + - description: "Convert.pytest" before: | __CROSS_SYNC_OUTPUT__ = "out.path" diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_sync/test_cross_sync_decorators.py index 727d5a599..bc9a92917 100644 --- a/tests/unit/data/_sync/test_cross_sync_decorators.py +++ b/tests/unit/data/_sync/test_cross_sync_decorators.py @@ -20,7 +20,7 @@ from google.cloud.bigtable.data._sync.cross_sync._decorators import ( ExportSync, Convert, - DropMethod, + Drop, Pytest, PytestFixture, 
) @@ -412,9 +412,9 @@ def test_sync_ast_transform_add_docstring_format( assert result.body[0].value.value == expected -class TestDropMethodDecorator: +class TestDropDecorator: def _get_class(self): - return DropMethod + return Drop def test_decorator_functionality(self): """ From 61490e5c53ef5b77a907b164c9e926578dc1f839 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 17:30:14 -0700 Subject: [PATCH 288/360] renamed export into convert_class --- .cross_sync/README.md | 4 +- .cross_sync/transformers.py | 2 +- .../data/_sync/cross_sync/_decorators.py | 6 +- .../data/_sync/cross_sync/cross_sync.py | 4 +- .../test_cases/cross_sync_files.yaml | 74 +++++++++++++------ .../data/_sync/test_cross_sync_decorators.py | 8 +- 6 files changed, 63 insertions(+), 35 deletions(-) diff --git a/.cross_sync/README.md b/.cross_sync/README.md index a2f193985..5f7a62581 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -46,8 +46,8 @@ Additionally, CrossSync provides method implementations that work equivalently i CrossSync provides a set of annotations to mark up async classes, to guide the generation of sync code. -- `@CrossSync.export_sync` - - marks classes for conversion. Unmarked classes will be droppd +- `@CrossSync.convert_sync` + - marks classes for conversion. Unmarked classes will be copied as-is - if add_mapping is included, the async and sync classes can be accessed using a shared CrossSync.X alias - `@CrossSync.convert` - marks async functions for conversion. Unmarked methods will be copied as-is diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index a780f2bb9..e47eb2f31 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -32,7 +32,7 @@ import sys # add cross_sync to path sys.path.append("google/cloud/bigtable/data/_sync/cross_sync") -from _decorators import AstDecorator, ExportSync +from _decorators import AstDecorator, ConvertClass class SymbolReplacer(ast.NodeTransformer): diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index 8f90efce0..4e79331dd 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -182,9 +182,9 @@ def _convert_ast_to_py(cls, ast_node: ast.expr | None) -> Any: raise ValueError(f"Unsupported type {type(ast_node)}") -class ExportSync(AstDecorator): +class ConvertClass(AstDecorator): """ - Class decorator for marking async classes to be converted to sync classes + Class decorator for guiding generation of sync classes Args: sync_name: use a new name for the sync class @@ -290,7 +290,7 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): return wrapped_node -class Convert(ExportSync): +class Convert(ConvertClass): """ Method decorator to mark async methods to be converted to sync methods diff --git a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py index 2433562cf..1f1ee111a 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py @@ -60,7 +60,7 @@ async def async_func(self, arg: int) -> int: import threading import time from ._decorators import ( - ExportSync, + ConvertClass, Convert, Drop, Pytest, @@ -97,7 +97,7 @@ class CrossSync(metaclass=MappingMeta): Generator: TypeAlias = AsyncGenerator # decorators - export_sync = ExportSync.decorator # decorate classes to convert + 
convert_class = ConvertClass.decorator # decorate classes to convert convert = Convert.decorator # decorate methods to convert from async to sync drop = Drop.decorator # decorate methods to remove from sync version pytest = Pytest.decorator # decorate test methods to run with pytest-asyncio diff --git a/tests/system/cross_sync/test_cases/cross_sync_files.yaml b/tests/system/cross_sync/test_cases/cross_sync_files.yaml index 3eca767ec..f6c439d0f 100644 --- a/tests/system/cross_sync/test_cases/cross_sync_files.yaml +++ b/tests/system/cross_sync/test_cases/cross_sync_files.yaml @@ -9,10 +9,10 @@ tests: - name: CrossSyncFileProcessor after: null - - description: "CrossSync.export_sync with default sync_name" + - description: "CrossSync.convert_class with default sync_name" before: | __CROSS_SYNC_OUTPUT__ = "out.path" - @CrossSync.export_sync + @CrossSync.convert_class class MyClass: async def my_method(self): pass @@ -25,10 +25,10 @@ tests: async def my_method(self): pass - - description: "CrossSync.export_sync with custom sync_name" + - description: "CrossSync.convert_class with custom sync_name" before: | __CROSS_SYNC_OUTPUT__ = "out.path" - @CrossSync.export_sync(sync_name="MyClass") + @CrossSync.convert_class(sync_name="MyClass") class MyAsyncClass: async def my_method(self): pass @@ -41,10 +41,10 @@ tests: async def my_method(self): pass - - description: "CrossSync.export_sync with replace_symbols" + - description: "CrossSync.convert_class with replace_symbols" before: | __CROSS_SYNC_OUTPUT__ = "out.path" - @CrossSync.export_sync( + @CrossSync.convert_class( sync_name="MyClass", replace_symbols={"AsyncBase": "SyncBase", "ParentA": "ParentB"} ) @@ -60,10 +60,10 @@ tests: def __init__(self, base: SyncBase): self.base = base - - description: "CrossSync.export_sync with docstring formatting" + - description: "CrossSync.convert_class with docstring formatting" before: | __CROSS_SYNC_OUTPUT__ = "out.path" - @CrossSync.export_sync( + @CrossSync.convert_class( sync_name="MyClass", docstring_format_vars={"type": ("async", "sync")} ) @@ -76,10 +76,10 @@ tests: class MyClass: """This is a sync class.""" - - description: "CrossSync.export_sync with multiple decorators and methods" + - description: "CrossSync.convert_class with multiple decorators and methods" before: | __CROSS_SYNC_OUTPUT__ = "out.path" - @CrossSync.export_sync(sync_name="MyClass") + @CrossSync.convert_class(sync_name="MyClass") @some_other_decorator class MyAsyncClass: @CrossSync.convert @@ -115,10 +115,10 @@ tests: def fixture(self): pass - - description: "CrossSync.export_sync with nested classes drop" + - description: "CrossSync.convert_class with nested classes drop" before: | __CROSS_SYNC_OUTPUT__ = "out.path" - @CrossSync.export_sync(sync_name="MyClass") + @CrossSync.convert_class(sync_name="MyClass") class MyAsyncClass: @CrossSync.drop class NestedAsyncClass: @@ -138,12 +138,12 @@ tests: nested = self.NestedAsyncClass() nested.nested_method() - - description: "CrossSync.export_sync with nested classes" + - description: "CrossSync.convert_class with nested classes explicit" before: | __CROSS_SYNC_OUTPUT__ = "out.path" - @CrossSync.export_sync(sync_name="MyClass", replace_symbols={"AsyncBase": "SyncBase"}) + @CrossSync.convert_class(sync_name="MyClass", replace_symbols={"AsyncBase": "SyncBase"}) class MyAsyncClass: - @CrossSync.export_sync + @CrossSync.convert_class class NestedClass: async def nested_method(self, base: AsyncBase): pass @@ -166,10 +166,38 @@ tests: nested = self.NestedAsyncClass() nested.nested_method() - - 
description: "CrossSync.export_sync with add_mapping" + - description: "CrossSync.convert_class with nested classes implicit" before: | __CROSS_SYNC_OUTPUT__ = "out.path" - @CrossSync.export_sync( + @CrossSync.convert_class(sync_name="MyClass", replace_symbols={"AsyncBase": "SyncBase"}) + class MyAsyncClass: + + class NestedClass: + async def nested_method(self, base: AsyncBase): + pass + + @CrossSync.convert + async def use_nested(self): + nested = self.NestedAsyncClass() + CrossSync.rm_aio(await nested.nested_method()) + transformers: + - name: CrossSyncFileProcessor + after: | + class MyClass: + + class NestedClass: + + async def nested_method(self, base: SyncBase): + pass + + def use_nested(self): + nested = self.NestedAsyncClass() + nested.nested_method() + + - description: "CrossSync.convert_class with add_mapping" + before: | + __CROSS_SYNC_OUTPUT__ = "out.path" + @CrossSync.convert_class( sync_name="MyClass", add_mapping_for_name="MyClass" ) @@ -186,10 +214,10 @@ tests: async def my_method(self): pass - - description: "CrossSync.export_sync with rm_aio" + - description: "CrossSync.convert_class with rm_aio" before: | __CROSS_SYNC_OUTPUT__ = "out.path" - @CrossSync.export_sync(rm_aio=True) + @CrossSync.convert_class(rm_aio=True) class MyClass: async def my_method(self): async for item in self.items: @@ -202,10 +230,10 @@ tests: for item in self.items: self.process(item) - - description: "CrossSync.export_sync with CrossSync calls" + - description: "CrossSync.convert_class with CrossSync calls" before: | __CROSS_SYNC_OUTPUT__ = "out.path" - @CrossSync.export_sync(sync_name="MyClass") + @CrossSync.convert_class(sync_name="MyClass") class MyAsyncClass: @CrossSync.convert async def my_method(self): @@ -365,7 +393,7 @@ tests: - description: "Convert method with multiple stacked decorators in class" before: | __CROSS_SYNC_OUTPUT__ = "out.path" - @CrossSync.export_sync + @CrossSync.convert_class class MyClass: @CrossSync.convert(sync_name="sync_multi_decorated") @CrossSync.pytest @@ -427,7 +455,7 @@ tests: before: | __CROSS_SYNC_OUTPUT__ = "out.path" CrossSync.sleep(1) - @CrossSync.export_sync + @CrossSync.convert_class class MyClass: event = CrossSync.Event() def my_method(self): diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_sync/test_cross_sync_decorators.py index bc9a92917..febb62267 100644 --- a/tests/unit/data/_sync/test_cross_sync_decorators.py +++ b/tests/unit/data/_sync/test_cross_sync_decorators.py @@ -18,7 +18,7 @@ from unittest import mock from google.cloud.bigtable.data._sync.cross_sync.cross_sync import CrossSync from google.cloud.bigtable.data._sync.cross_sync._decorators import ( - ExportSync, + ConvertClass, Convert, Drop, Pytest, @@ -37,9 +37,9 @@ def globals_mock(): return global_dict -class TestExportSyncDecorator: +class TestConvertClassDecorator: def _get_class(self): - return ExportSync + return ConvertClass def test_ctor_defaults(self): """ @@ -116,7 +116,7 @@ def test_class_decorator_docstring_update(self, docstring, format_vars, expected of the class being decorated """ - @ExportSync.decorator(sync_name="s", docstring_format_vars=format_vars) + @ConvertClass.decorator(sync_name="s", docstring_format_vars=format_vars) class Class: __doc__ = docstring From b8b2181b284fa5f6b6551c4839d020c3b45b3ea8 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 17:36:18 -0700 Subject: [PATCH 289/360] use new decorator names --- .../bigtable/data/_async/_mutate_rows.py | 2 +- .../cloud/bigtable/data/_async/_read_rows.py | 2 +- 
google/cloud/bigtable/data/_async/client.py | 4 +-- .../bigtable/data/_async/mutations_batcher.py | 4 +-- .../_async/execute_query_iterator.py | 2 +- tests/system/data/test_system_async.py | 6 ++-- tests/unit/data/_async/test__mutate_rows.py | 2 +- tests/unit/data/_async/test__read_rows.py | 2 +- tests/unit/data/_async/test_client.py | 28 +++++++++---------- .../data/_async/test_mutations_batcher.py | 4 +-- .../data/_async/test_read_rows_acceptance.py | 2 +- .../_async/test_query_iterator.py | 4 +-- 12 files changed, 31 insertions(+), 31 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index f0a7c050b..ce908ee26 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -47,7 +47,7 @@ __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync._mutate_rows" -@CrossSync.export_sync("_MutateRowsOperation") +@CrossSync.convert_class("_MutateRowsOperation") class _MutateRowsOperationAsync: """ MutateRowsOperation manages the logic of sending a set of row mutations, diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index e8219f75c..a68660daa 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -45,7 +45,7 @@ __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync._read_rows" -@CrossSync.export_sync("_ReadRowsOperation") +@CrossSync.convert_class("_ReadRowsOperation") class _ReadRowsOperationAsync: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 85f24ef7e..72b2792c0 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -109,7 +109,7 @@ __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync.client" -@CrossSync.export_sync( +@CrossSync.convert_class( sync_name="BigtableDataClient", add_mapping_for_name="DataClient", ) @@ -638,7 +638,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): CrossSync.rm_aio(await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb)) -@CrossSync.export_sync(sync_name="Table", add_mapping_for_name="Table") +@CrossSync.convert_class(sync_name="Table", add_mapping_for_name="Table") class TableAsync: """ Main Data API surface diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 2bd3f7b35..603f0886a 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -45,7 +45,7 @@ __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync.mutations_batcher" -@CrossSync.export_sync(sync_name="_FlowControl", add_mapping_for_name="_FlowControl") +@CrossSync.convert_class(sync_name="_FlowControl", add_mapping_for_name="_FlowControl") class _FlowControlAsync: """ Manages flow control for batched mutations. 
Mutations are registered against @@ -176,7 +176,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] yield mutations[start_idx:end_idx] -@CrossSync.export_sync( +@CrossSync.convert_class( sync_name="MutationsBatcher", add_mapping_for_name="MutationsBatcher" ) class MutationsBatcherAsync: diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index 1e780ea75..367a2925e 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -55,7 +55,7 @@ ) -@CrossSync.export_sync(sync_name="ExecuteQueryIterator") +@CrossSync.convert_class(sync_name="ExecuteQueryIterator") class ExecuteQueryIteratorAsync: @CrossSync.convert( docstring_format_vars={ diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 6d068c152..ad9c973d2 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -30,7 +30,7 @@ __CROSS_SYNC_OUTPUT__ = "tests.system.data.test_system" -@CrossSync.export_sync( +@CrossSync.convert_class( sync_name="TempRowBuilder", add_mapping_for_name="TempRowBuilder", ) @@ -80,7 +80,7 @@ async def delete_rows(self): CrossSync.rm_aio(await self.table.client._gapic_client.mutate_rows(request)) -@CrossSync.export_sync(sync_name="TestSystem") +@CrossSync.convert_class(sync_name="TestSystem") class TestSystemAsync: @CrossSync.convert @CrossSync.pytest_fixture(scope="session") @@ -102,7 +102,7 @@ async def table(self, client, table_id, instance_id): ) as table: yield table - @CrossSync.drop_method + @CrossSync.drop @pytest.fixture(scope="session") def event_loop(self): loop = asyncio.get_event_loop() diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 35184652c..13a30cb37 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -30,7 +30,7 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test__mutate_rows" -@CrossSync.export_sync("TestMutateRowsOperation") +@CrossSync.convert_class("TestMutateRowsOperation") class TestMutateRowsOperation: def _target_class(self): return CrossSync._MutateRowsOperation diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index c41914e84..4d63a0fb1 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -25,7 +25,7 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test__read_rows" -@CrossSync.export_sync( +@CrossSync.convert_class( sync_name="TestReadRowsOperation", ) class TestReadRowsOperationAsync: diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index bffd4240c..f1999f802 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -50,7 +50,7 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_client" -@CrossSync.export_sync( +@CrossSync.convert_class( sync_name="TestBigtableDataClient", add_mapping_for_name="TestBigtableDataClient", ) @@ -270,7 +270,7 @@ async def test_channel_pool_replace(self): assert client.transport._grpc_channel._pool[i] != start_pool[i] await client.close() - @CrossSync.drop_method + @CrossSync.drop @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test__start_background_channel_refresh_sync(self): # should raise RuntimeError if 
called in a sync context @@ -314,7 +314,7 @@ async def test__start_background_channel_refresh(self, pool_size): ping_and_warm.assert_any_call(channel) await client.close() - @CrossSync.drop_method + @CrossSync.drop @CrossSync.pytest @pytest.mark.skipif( sys.version_info < (3, 8), reason="Task.name requires python3.8 or higher" @@ -1095,7 +1095,7 @@ async def test_context_manager(self): # actually close the client await true_close - @CrossSync.drop_method + @CrossSync.drop def test_client_ctor_sync(self): # initializing client in a sync context should raise RuntimeError @@ -1111,7 +1111,7 @@ def test_client_ctor_sync(self): assert client._channel_refresh_tasks == [] -@CrossSync.export_sync("TestTable", add_mapping_for_name="TestTable") +@CrossSync.convert_class("TestTable", add_mapping_for_name="TestTable") class TestTableAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -1240,7 +1240,7 @@ async def test_table_ctor_invalid_timeout_values(self): assert "operation_timeout must be greater than 0" in str(e.value) await client.close() - @CrossSync.drop_method + @CrossSync.drop def test_table_ctor_sync(self): # initializing client in a sync context should raise RuntimeError client = mock.Mock() @@ -1423,7 +1423,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ assert "app_profile_id=" not in goog_metadata -@CrossSync.export_sync( +@CrossSync.convert_class( "TestReadRows", add_mapping_for_name="TestReadRows", ) @@ -1936,7 +1936,7 @@ async def test_row_exists(self, return_value, expected_result): assert query.filter._to_dict() == expected_filter -@CrossSync.export_sync("TestReadRowsSharded") +@CrossSync.convert_class("TestReadRowsSharded") class TestReadRowsShardedAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -2156,7 +2156,7 @@ async def mock_call(*args, **kwargs): ) -@CrossSync.export_sync("TestSampleRowKeys") +@CrossSync.convert_class("TestSampleRowKeys") class TestSampleRowKeysAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -2310,7 +2310,7 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio await table.sample_row_keys() -@CrossSync.export_sync("TestMutateRow") +@CrossSync.convert_class("TestMutateRow") class TestMutateRowAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -2487,7 +2487,7 @@ async def test_mutate_row_no_mutations(self, mutations): assert e.value.args[0] == "No mutations provided" -@CrossSync.export_sync("TestBulkMutateRows") +@CrossSync.convert_class("TestBulkMutateRows") class TestBulkMutateRowsAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -2870,7 +2870,7 @@ async def test_bulk_mutate_error_recovery(self): await table.bulk_mutate_rows(entries, operation_timeout=1000) -@CrossSync.export_sync("TestCheckAndMutateRow") +@CrossSync.convert_class("TestCheckAndMutateRow") class TestCheckAndMutateRowAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -3023,7 +3023,7 @@ async def test_check_and_mutate_mutations_parsing(self): ) -@CrossSync.export_sync("TestReadModifyWriteRow") +@CrossSync.convert_class("TestReadModifyWriteRow") class TestReadModifyWriteRowAsync: @CrossSync.convert def _make_client(self, *args, **kwargs): @@ -3155,7 +3155,7 @@ async def test_read_modify_write_row_building(self): constructor_mock.assert_called_once_with(mock_response.row) -@CrossSync.export_sync("TestExecuteQuery") +@CrossSync.convert_class("TestExecuteQuery") class TestExecuteQueryAsync: TABLE_NAME = "TABLE_NAME" 
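# Recap of the renamed annotations as they are used throughout this patch. The class name and
# output path below are made up for illustration, but the decorator names, keyword arguments,
# and rm_aio wrapping mirror the real call sites changed above.
from google.cloud.bigtable.data._sync.cross_sync import CrossSync

__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync.example"  # hypothetical output module

@CrossSync.convert_class(sync_name="Example", add_mapping_for_name="Example")
class ExampleAsync:
    @CrossSync.convert
    async def fetch(self):
        # rm_aio marks the await for removal in the generated sync surface
        return CrossSync.rm_aio(await self._rpc())

    @CrossSync.drop
    async def async_only_helper(self):
        ...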
INSTANCE_NAME = "INSTANCE_NAME" diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index cdfdc5ab7..2c0c103a1 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -26,7 +26,7 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_mutations_batcher" -@CrossSync.export_sync(sync_name="Test_FlowControl") +@CrossSync.convert_class(sync_name="Test_FlowControl") class Test_FlowControl: @staticmethod @CrossSync.convert @@ -296,7 +296,7 @@ async def test_add_to_flow_oversize(self): assert len(count_results) == 1 -@CrossSync.export_sync(sync_name="TestMutationsBatcher") +@CrossSync.convert_class(sync_name="TestMutationsBatcher") class TestMutationsBatcherAsync: @CrossSync.convert def _get_target_class(self): diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index bf5af9786..0f275ca8c 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -33,7 +33,7 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_read_rows_acceptance" -@CrossSync.export_sync( +@CrossSync.convert_class( sync_name="TestReadRowsAcceptance", ) class TestReadRowsAcceptanceAsync: diff --git a/tests/unit/data/execute_query/_async/test_query_iterator.py b/tests/unit/data/execute_query/_async/test_query_iterator.py index d24c47466..9cef2dfb3 100644 --- a/tests/unit/data/execute_query/_async/test_query_iterator.py +++ b/tests/unit/data/execute_query/_async/test_query_iterator.py @@ -30,7 +30,7 @@ __CROSS_SYNC_OUTPUT__ = "tests.unit.data.execute_query._sync.test_query_iterator" -@CrossSync.export_sync(sync_name="MockIterator") +@CrossSync.convert_class(sync_name="MockIterator") class MockIterator: def __init__(self, values, delay=None): self._values = values @@ -52,7 +52,7 @@ async def __anext__(self): return value -@CrossSync.export_sync(sync_name="TestQueryIterator") +@CrossSync.convert_class(sync_name="TestQueryIterator") class TestQueryIteratorAsync: @staticmethod def _target_class(): From c6fcd6999d92ff6a68db266b5addbb4537a42d84 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Sep 2024 17:38:46 -0700 Subject: [PATCH 290/360] regenerate files --- .../cloud/bigtable/data/_sync/_mutate_rows.py | 7 +-- .../cloud/bigtable/data/_sync/_read_rows.py | 8 +-- google/cloud/bigtable/data/_sync/client.py | 9 ++-- .../bigtable/data/_sync/mutations_batcher.py | 11 ++-- .../_sync/execute_query_iterator.py | 7 ++- tests/system/data/test_system.py | 8 +-- tests/unit/data/_sync/test__mutate_rows.py | 7 +-- tests/unit/data/_sync/test__read_rows.py | 7 +-- tests/unit/data/_sync/test_client.py | 53 +++++-------------- .../unit/data/_sync/test_mutations_batcher.py | 45 +++++----------- .../data/_sync/test_read_rows_acceptance.py | 8 +-- .../_sync/test_query_iterator.py | 5 +- 12 files changed, 66 insertions(+), 109 deletions(-) diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync/_mutate_rows.py index ff94691a7..232ddced9 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync/_mutate_rows.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # + # This file is automatically generated by CrossSync. Do not edit manually. + from __future__ import annotations from typing import Sequence, TYPE_CHECKING import functools @@ -32,10 +34,9 @@ BigtableClient as GapicClientType, ) from google.cloud.bigtable.data._sync.client import Table as TableType -__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync._mutate_rows" -class MutateRowsOperation: +class _MutateRowsOperation: """ MutateRowsOperation manages the logic of sending a set of row mutations, and retrying on failed entries. It manages this using the _run_attempt diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync/_read_rows.py index 373ec2884..05254279e 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync/_read_rows.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # + + # This file is automatically generated by CrossSync. Do not edit manually. + from __future__ import annotations from typing import Sequence, TYPE_CHECKING from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB @@ -33,10 +36,9 @@ if TYPE_CHECKING: from google.cloud.bigtable.data._sync.client import Table as TableType -__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync._read_rows" -class ReadRowsOperation: +class _ReadRowsOperation: """ ReadRowsOperation handles the logic of merging chunks from a ReadRowsResponse stream into a stream of Row objects. diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync/client.py index cc4c03e5f..d6c00c91e 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync/client.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # + + # This file is automatically generated by CrossSync. Do not edit manually. 
+ from __future__ import annotations from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING import time @@ -64,6 +67,7 @@ from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter from google.cloud.bigtable.data.row_filters import RowFilterChain from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from typing import Iterable from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( PooledBigtableGrpcTransport as PooledTransportType, ) @@ -75,7 +79,6 @@ if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery -__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync.client" @CrossSync._Sync_Impl.add_mapping_decorator("DataClient") @@ -175,8 +178,6 @@ def __init__( def _client_version() -> str: """Helper function to return the client version string for this client""" version_str = f"{google.cloud.bigtable.__version__}-data" - if CrossSync._Sync_Impl.is_async: - version_str += "-async" return version_str def _start_background_channel_refresh(self) -> None: diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync/mutations_batcher.py index 924418201..800774d36 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync/mutations_batcher.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # + # This file is automatically generated by CrossSync. Do not edit manually. + from __future__ import annotations -from typing import Sequence, TYPE_CHECKING +from typing import Sequence, TYPE_CHECKING, cast import atexit import warnings from collections import deque @@ -32,7 +34,6 @@ if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data._sync.client import Table as TableType -__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync.mutations_batcher" @CrossSync._Sync_Impl.add_mapping_decorator("_FlowControl") @@ -260,7 +261,7 @@ def append(self, mutation_entry: RowMutationEntry): ValueError: if an invalid mutation type is added""" if self._closed.is_set(): raise RuntimeError("Cannot append to closed MutationsBatcher") - if isinstance(mutation_entry, Mutation): + if isinstance(cast(Mutation, mutation_entry), Mutation): raise ValueError( f"invalid mutation type: {type(mutation_entry).__name__}. Only RowMutationEntry objects are supported by batcher" ) @@ -438,8 +439,6 @@ def _wait_for_batch_results( return [] exceptions: list[Exception] = [] for task in tasks: - if CrossSync._Sync_Impl.is_async: - task try: exc_list = task.result() if exc_list: diff --git a/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py index 7523e11d6..974ee3964 100644 --- a/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
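# Note on the cast() call carried into the regenerated mutations_batcher above: typing.cast is
# a no-op at runtime, so the isinstance guard still rejects a bare Mutation passed where a
# RowMutationEntry is expected; the cast is only there to keep the type checker happy (the
# originating commit is titled "fix annotations"). SetCell is used below as a representative
# Mutation subclass, assuming the constructor shape used elsewhere in this library.
from typing import cast
from google.cloud.bigtable.data.mutations import Mutation, SetCell

entry = SetCell("family", b"qualifier", b"value")   # a bare Mutation, not a RowMutationEntry
assert cast(Mutation, entry) is entry                # cast changes nothing at runtime
assert isinstance(cast(Mutation, entry), Mutation)   # so the guard still fires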
-# + + # This file is automatically generated by CrossSync. Do not edit manually. + from __future__ import annotations from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING from google.api_core import retry as retries @@ -35,9 +37,6 @@ if TYPE_CHECKING: from google.cloud.bigtable.data import BigtableDataClient as DataClientType -__CROSS_SYNC_OUTPUT__ = ( - "google.cloud.bigtable.data.execute_query._sync.execute_query_iterator" -) class ExecuteQueryIterator: diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py index 0c77623cf..381256020 100644 --- a/tests/system/data/test_system.py +++ b/tests/system/data/test_system.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + + # This file is automatically generated by CrossSync. Do not edit manually. + import pytest import uuid import os @@ -23,8 +25,6 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync from . import TEST_FAMILY, TEST_FAMILY_2 -__CROSS_SYNC_OUTPUT__ = "tests.system.data.test_system" - @CrossSync._Sync_Impl.add_mapping_decorator("TempRowBuilder") class TempRowBuilder: diff --git a/tests/unit/data/_sync/test__mutate_rows.py b/tests/unit/data/_sync/test__mutate_rows.py index 33510630c..59c7074a8 100644 --- a/tests/unit/data/_sync/test__mutate_rows.py +++ b/tests/unit/data/_sync/test__mutate_rows.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + + # This file is automatically generated by CrossSync. Do not edit manually. + import pytest from google.cloud.bigtable_v2.types import MutateRowsResponse from google.rpc import status_pb2 @@ -24,7 +26,6 @@ from unittest import mock except ImportError: import mock -__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test__mutate_rows" class TestMutateRowsOperation: diff --git a/tests/unit/data/_sync/test__read_rows.py b/tests/unit/data/_sync/test__read_rows.py index 84dd95d96..dc6c24f5b 100644 --- a/tests/unit/data/_sync/test__read_rows.py +++ b/tests/unit/data/_sync/test__read_rows.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + + # This file is automatically generated by CrossSync. Do not edit manually. 
+ import pytest from google.cloud.bigtable.data._sync.cross_sync import CrossSync @@ -20,7 +22,6 @@ from unittest import mock except ImportError: import mock -__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test__read_rows" class TestReadRowsOperation: diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index d2936b714..49c052129 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,13 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + # This file is automatically generated by CrossSync. Do not edit manually. + from __future__ import annotations import grpc import asyncio import re import pytest +import mock from google.cloud.bigtable.data import mutations from google.auth.credentials import AnonymousCredentials from google.cloud.bigtable_v2.types import ReadRowsResponse @@ -30,15 +32,9 @@ from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse from google.cloud.bigtable.data._sync.cross_sync import CrossSync - -try: - from unittest import mock -except ImportError: - import mock from google.api_core import grpc_helpers -CrossSync.add_mapping("grpc_helpers", grpc_helpers) -__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_client" +CrossSync._Sync_Impl.add_mapping("grpc_helpers", grpc_helpers) @CrossSync._Sync_Impl.add_mapping_decorator("TestBigtableDataClient") @@ -154,10 +150,7 @@ def test_veneer_grpc_headers(self): + client_component + " gl-python\\/[0-9]+\\.[\\w.-]+ grpc\\/[0-9]+\\.[\\w.-]+" ) - if CrossSync._Sync_Impl.is_async: - patch = mock.patch("google.api_core.gapic_v1.method_async.wrap_method") - else: - patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") + patch = mock.patch("google.api_core.gapic_v1.method.wrap_method") with patch as gapic_mock: client = self._make_client(project="project-id") wrapped_call_list = gapic_mock.call_args_list @@ -232,9 +225,6 @@ def test_channel_pool_replace(self): replace_idx, grace=grace_period, new_channel=new_channel ) close.assert_called_once() - if CrossSync._Sync_Impl.is_async: - close.assert_called_once_with(grace=grace_period) - close.assert_awaited_once() assert client.transport._grpc_channel._pool[replace_idx] == new_channel for i in range(pool_size): if i != replace_idx: @@ -266,12 +256,7 @@ def test__start_background_channel_refresh(self, pool_size): client._start_background_channel_refresh() assert len(client._channel_refresh_tasks) == pool_size for task in client._channel_refresh_tasks: - if CrossSync._Sync_Impl.is_async: - assert isinstance(task, asyncio.Task) - else: - assert isinstance(task, concurrent.futures.Future) - if CrossSync._Sync_Impl.is_async: - asyncio.sleep(0.1) + assert isinstance(task, concurrent.futures.Future) assert ping_and_warm.call_count == pool_size for channel in client.transport._grpc_channel._pool: ping_and_warm.assert_any_call(channel) @@ -309,8 +294,6 @@ def test__ping_and_warm_instances(self): gather.assert_called_once() partial_list = gather.call_args.args[0] assert len(partial_list) == 4 - if CrossSync._Sync_Impl.is_async: - gather.assert_awaited_once() grpc_call_args 
= channel.unary_unary().call_args_list for idx, (_, kwargs) in enumerate(grpc_call_args): ( @@ -463,12 +446,9 @@ def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sle except asyncio.CancelledError: pass assert sleep.call_count == num_cycles - if CrossSync._Sync_Impl.is_async: - total_sleep = sum([call[0][0] for call in sleep.call_args_list]) - else: - total_sleep = sum( - [call[1]["timeout"] for call in sleep.call_args_list] - ) + total_sleep = sum( + [call[1]["timeout"] for call in sleep.call_args_list] + ) assert ( abs(total_sleep - expected_sleep) < 0.1 ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" @@ -919,8 +899,6 @@ def test_close(self): ) as close_mock: client.close() close_mock.assert_called_once() - if CrossSync._Sync_Impl.is_async: - close_mock.assert_awaited() for task in tasks_list: assert task.done() @@ -934,8 +912,6 @@ def test_close_with_timeout(self): ) as wait_for_mock: client.close(timeout=expected_timeout) wait_for_mock.assert_called_once() - if CrossSync._Sync_Impl.is_async: - wait_for_mock.assert_awaited() assert wait_for_mock.call_args[1]["timeout"] == expected_timeout client._channel_refresh_tasks = tasks client.close() @@ -952,8 +928,6 @@ def test_context_manager(self): assert client._active_instances == set() close_mock.assert_not_called() close_mock.assert_called_once() - if CrossSync._Sync_Impl.is_async: - close_mock.assert_awaited() true_close @@ -1126,10 +1100,7 @@ def test_customizable_retryable_errors( retry_fn = "retry_target" if is_stream: retry_fn += "_stream" - if CrossSync._Sync_Impl.is_async: - retry_fn = f"CrossSync.{retry_fn}" - else: - retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" + retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" with mock.patch( f"google.cloud.bigtable.data._sync.cross_sync.{retry_fn}" ) as retry_fn_mock: @@ -1271,6 +1242,7 @@ def _make_gapic_stream( ): from google.cloud.bigtable_v2 import ReadRowsResponse + @CrossSync._Sync_Impl.export_sync class mock_stream: def __init__(self, chunk_list, sleep_time): self.chunk_list = chunk_list @@ -2749,6 +2721,7 @@ def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) def _make_gapic_stream(self, sample_list: list["ExecuteQueryResponse" | Exception]): + @CrossSync._Sync_Impl.export_sync class MockStream: def __init__(self, sample_list): self.sample_list = sample_list diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py index 2064df07c..7b48b6682 100644 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + + # This file is automatically generated by CrossSync. Do not edit manually. 
+ import pytest +import mock import asyncio import time import google.api_core.exceptions as core_exceptions @@ -22,12 +25,6 @@ from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data._sync.cross_sync import CrossSync -try: - from unittest import mock -except ImportError: - import mock -__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_mutations_batcher" - class Test_FlowControl: @staticmethod @@ -150,18 +147,11 @@ def task_routine(): lambda: instance._has_capacity(1, 1) ) - if CrossSync._Sync_Impl.is_async: - task = asyncio.create_task(task_routine()) - - def task_alive(): - return not task.done() - - else: - import threading + import threading - thread = threading.Thread(target=task_routine) - thread.start() - task_alive = thread.is_alive + thread = threading.Thread(target=task_routine) + thread.start() + task_alive = thread.is_alive CrossSync._Sync_Impl.sleep(0.05) assert task_alive() is True mutation = self._make_mutation(count=0, size=5) @@ -451,10 +441,7 @@ def test__start_flush_timer_w_empty_input(self, input_val): self._get_target_class(), "_schedule_flush" ) as flush_mock: with self._make_one() as instance: - if CrossSync._Sync_Impl.is_async: - (sleep_obj, sleep_method) = (asyncio, "wait_for") - else: - (sleep_obj, sleep_method) = (instance._closed, "wait") + (sleep_obj, sleep_method) = (instance._closed, "wait") with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: result = instance._timer_routine(input_val) assert sleep_mock.call_count == 0 @@ -470,10 +457,7 @@ def test__start_flush_timer_call_when_closed(self): with self._make_one() as instance: instance.close() flush_mock.reset_mock() - if CrossSync._Sync_Impl.is_async: - (sleep_obj, sleep_method) = (asyncio, "wait_for") - else: - (sleep_obj, sleep_method) = (instance._closed, "wait") + (sleep_obj, sleep_method) = (instance._closed, "wait") with mock.patch.object(sleep_obj, sleep_method) as sleep_mock: instance._timer_routine(10) assert sleep_mock.call_count == 0 @@ -500,8 +484,6 @@ def test__flush_timer(self, num_staged): self._get_target_class()._timer_routine( instance, expected_sleep ) - if CrossSync._Sync_Impl.is_async: - instance._flush_timer = CrossSync._Sync_Impl.Future() assert sleep_mock.call_count == loop_num + 1 sleep_kwargs = sleep_mock.call_args[1] assert sleep_kwargs["timeout"] == expected_sleep @@ -733,10 +715,7 @@ def test_flush_clears_job_list(self): assert instance._flush_jobs == set() new_job = instance._schedule_flush() assert instance._flush_jobs == {new_job} - if CrossSync._Sync_Impl.is_async: - new_job - else: - new_job.result() + new_job.result() assert instance._flush_jobs == set() @pytest.mark.parametrize( diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync/test_read_rows_acceptance.py index 9716391eb..f1e725931 100644 --- a/tests/unit/data/_sync/test_read_rows_acceptance.py +++ b/tests/unit/data/_sync/test_read_rows_acceptance.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + # This file is automatically generated by CrossSync. Do not edit manually. 
+ from __future__ import annotations import os import warnings @@ -25,8 +26,6 @@ from ...v2_client.test_row_merger import ReadRowsTest, TestFile from google.cloud.bigtable.data._sync.cross_sync import CrossSync -__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_read_rows_acceptance" - class TestReadRowsAcceptance: @staticmethod @@ -121,6 +120,7 @@ def test_read_rows_scenario(self, test_case: ReadRowsTest): def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): from google.cloud.bigtable_v2 import ReadRowsResponse + @CrossSync._Sync_Impl.export_sync class mock_stream: def __init__(self, chunk_list): self.chunk_list = chunk_list diff --git a/tests/unit/data/execute_query/_sync/test_query_iterator.py b/tests/unit/data/execute_query/_sync/test_query_iterator.py index 8621d7dd5..f1142d956 100644 --- a/tests/unit/data/execute_query/_sync/test_query_iterator.py +++ b/tests/unit/data/execute_query/_sync/test_query_iterator.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + + # This file is automatically generated by CrossSync. Do not edit manually. + import pytest import concurrent.futures from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse @@ -23,7 +25,6 @@ from unittest import mock except ImportError: import mock -__CROSS_SYNC_OUTPUT__ = "tests.unit.data.execute_query._sync.test_query_iterator" class MockIterator: From 5e3019209d22395af8b8ac0284995714cb452a35 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 13 Sep 2024 13:04:24 -0700 Subject: [PATCH 291/360] removed unused annotation --- tests/unit/data/_async/test_client.py | 2 -- tests/unit/data/_async/test_read_rows_acceptance.py | 1 - tests/unit/data/_sync/test_client.py | 2 -- tests/unit/data/_sync/test_read_rows_acceptance.py | 1 - tests/unit/data/execute_query/_async/test_query_iterator.py | 1 - 5 files changed, 7 deletions(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 9f2348323..f1999f802 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1496,7 +1496,6 @@ async def _make_gapic_stream( ): from google.cloud.bigtable_v2 import ReadRowsResponse - @CrossSync.export_sync class mock_stream: def __init__(self, chunk_list, sleep_time): self.chunk_list = chunk_list @@ -3167,7 +3166,6 @@ def _make_client(self, *args, **kwargs): @CrossSync.convert def _make_gapic_stream(self, sample_list: list["ExecuteQueryResponse" | Exception]): - @CrossSync.export_sync class MockStream: def __init__(self, sample_list): self.sample_list = sample_list diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index b1b15c48f..0f275ca8c 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -137,7 +137,6 @@ async def test_read_rows_scenario(self, test_case: ReadRowsTest): async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): from google.cloud.bigtable_v2 import ReadRowsResponse - @CrossSync.export_sync class mock_stream: def __init__(self, chunk_list): self.chunk_list = chunk_list diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 49c052129..45b99a9e3 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -1242,7 +1242,6 @@ def _make_gapic_stream( ): from 
google.cloud.bigtable_v2 import ReadRowsResponse - @CrossSync._Sync_Impl.export_sync class mock_stream: def __init__(self, chunk_list, sleep_time): self.chunk_list = chunk_list @@ -2721,7 +2720,6 @@ def _make_client(self, *args, **kwargs): return CrossSync._Sync_Impl.TestBigtableDataClient._make_client(*args, **kwargs) def _make_gapic_stream(self, sample_list: list["ExecuteQueryResponse" | Exception]): - @CrossSync._Sync_Impl.export_sync class MockStream: def __init__(self, sample_list): self.sample_list = sample_list diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync/test_read_rows_acceptance.py index f1e725931..61909bf99 100644 --- a/tests/unit/data/_sync/test_read_rows_acceptance.py +++ b/tests/unit/data/_sync/test_read_rows_acceptance.py @@ -120,7 +120,6 @@ def test_read_rows_scenario(self, test_case: ReadRowsTest): def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): from google.cloud.bigtable_v2 import ReadRowsResponse - @CrossSync._Sync_Impl.export_sync class mock_stream: def __init__(self, chunk_list): self.chunk_list = chunk_list diff --git a/tests/unit/data/execute_query/_async/test_query_iterator.py b/tests/unit/data/execute_query/_async/test_query_iterator.py index 9cef2dfb3..814d5e084 100644 --- a/tests/unit/data/execute_query/_async/test_query_iterator.py +++ b/tests/unit/data/execute_query/_async/test_query_iterator.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); From 26315e499d6f3e71a28fd36db7b70265b5ab9f14 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 13 Sep 2024 13:07:05 -0700 Subject: [PATCH 292/360] ran blacken --- tests/unit/data/test_sync_up_to_date.py | 31 +++++++++++-------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/tests/unit/data/test_sync_up_to_date.py b/tests/unit/data/test_sync_up_to_date.py index d170638c4..22dcead74 100644 --- a/tests/unit/data/test_sync_up_to_date.py +++ b/tests/unit/data/test_sync_up_to_date.py @@ -33,9 +33,7 @@ @pytest.mark.skipif( sys.version_info < (3, 9), reason="ast.unparse is only available in 3.9+" ) -@pytest.mark.parametrize( - "sync_file", sync_files, ids=lambda f: f.output_path -) +@pytest.mark.parametrize("sync_file", sync_files, ids=lambda f: f.output_path) def test_sync_up_to_date(sync_file): """ Generate a fresh copy of each cross_sync file, and compare hashes with the existing file. 
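# A condensed sketch of the comparison performed in the hunk below (names match the
# test body shown in this diff; the on-disk copy is re-rendered through the same
# black formatting, presumably so the check is insensitive to formatting drift and
# only flags real content changes):
#
#     fresh = sync_file.render(with_black=True, save_to_disk=False)
#     on_disk = CrossSyncOutputFile(
#         output_path="",
#         ast_tree=ast.parse(open(sync_file.output_path).read()),
#         header=sync_file.header,
#     ).render(with_black=True, save_to_disk=False)
#     assert hashlib.md5(fresh.encode()).hexdigest() == hashlib.md5(on_disk.encode()).hexdigest()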
@@ -45,9 +43,7 @@ def test_sync_up_to_date(sync_file): path = sync_file.output_path new_render = sync_file.render(with_black=True, save_to_disk=False) found_render = CrossSyncOutputFile( - output_path="", - ast_tree=ast.parse(open(path).read()), - header=sync_file.header + output_path="", ast_tree=ast.parse(open(path).read()), header=sync_file.header ).render(with_black=True, save_to_disk=False) # compare by content diff = unified_diff(found_render.splitlines(), new_render.splitlines(), lineterm="") @@ -58,23 +54,22 @@ def test_sync_up_to_date(sync_file): found_hash = hashlib.md5(found_render.encode()).hexdigest() assert new_hash == found_hash, f"md5 mismatch for {path}" -@pytest.mark.parametrize( - "sync_file", sync_files, ids=lambda f: f.output_path -) + +@pytest.mark.parametrize("sync_file", sync_files, ids=lambda f: f.output_path) def test_verify_headers(sync_file): license_regex = r""" - \#\ Copyright\ \d{4}\ Google\ LLC\n - \#\n - \#\ Licensed\ under\ the\ Apache\ License,\ Version\ 2\.0\ \(the\ \"License\"\);\n - \#\ you\ may\ not\ use\ this\ file\ except\ in\ compliance\ with\ the\ License\.\n + \#\ Copyright\ \d{4}\ Google\ LLC\n + \#\n + \#\ Licensed\ under\ the\ Apache\ License,\ Version\ 2\.0\ \(the\ \"License\"\);\n + \#\ you\ may\ not\ use\ this\ file\ except\ in\ compliance\ with\ the\ License\.\n \#\ You\ may\ obtain\ a\ copy\ of\ the\ License\ at\ \#\n \#\s+http:\/\/www\.apache\.org\/licenses\/LICENSE-2\.0\n - \#\n - \#\ Unless\ required\ by\ applicable\ law\ or\ agreed\ to\ in\ writing,\ software\n - \#\ distributed\ under\ the\ License\ is\ distributed\ on\ an\ \"AS\ IS\"\ BASIS,\n - \#\ WITHOUT\ WARRANTIES\ OR\ CONDITIONS\ OF\ ANY\ KIND,\ either\ express\ or\ implied\.\n - \#\ See\ the\ License\ for\ the\ specific\ language\ governing\ permissions\ and\n + \#\n + \#\ Unless\ required\ by\ applicable\ law\ or\ agreed\ to\ in\ writing,\ software\n + \#\ distributed\ under\ the\ License\ is\ distributed\ on\ an\ \"AS\ IS\"\ BASIS,\n + \#\ WITHOUT\ WARRANTIES\ OR\ CONDITIONS\ OF\ ANY\ KIND,\ either\ express\ or\ implied\.\n + \#\ See\ the\ License\ for\ the\ specific\ language\ governing\ permissions\ and\n \#\ limitations\ under\ the\ License\. 
""" pattern = re.compile(license_regex, re.VERBOSE) From a45a6f6599096ea18bc319874763fad4ab51988a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 13 Sep 2024 15:40:35 -0700 Subject: [PATCH 293/360] extracted branch trimming into own transformer --- .cross_sync/transformers.py | 40 ++++++++-- .../strip_async_conditional_branches.yaml | 74 +++++++++++++++++++ .../system/cross_sync/test_cross_sync_e2e.py | 1 + 3 files changed, 107 insertions(+), 8 deletions(-) create mode 100644 tests/system/cross_sync/test_cases/strip_async_conditional_branches.yaml diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index e47eb2f31..3e3f4de82 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -198,6 +198,36 @@ def visit_AsyncFor(self, node): return self.generic_visit(node) +class StripAsyncConditionalBranches(ast.NodeTransformer): + """ + Visits all if statements in an AST, and removes branches marked with CrossSync.is_async + """ + + def visit_If(self, node): + """ + remove CrossSync.is_async branches from top-level if statements + """ + kept_branch = None + # check for CrossSync.is_async + if self._is_async_check(node.test): + kept_branch = node.orelse + # check for not CrossSync.is_async + elif isinstance(node.test, ast.UnaryOp) and isinstance(node.test.op, ast.Not) and self._is_async_check(node.test.operand): + kept_branch = node.body + if kept_branch is not None: + # only keep the statements in the kept branch + return [self.visit(n) for n in kept_branch] + else: + # keep the entire if statement + return self.visit(node) + + def _is_async_check(self, node) -> bool: + """ + Check for CrossSync.is_async nodes + """ + return isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name) and node.value.id == "CrossSync" and node.attr == "is_async" + + class CrossSyncFileProcessor(ast.NodeTransformer): """ Visits a file, looking for __CROSS_SYNC_OUTPUT__ annotations @@ -228,6 +258,8 @@ def visit_Module(self, node): converted = self.generic_visit(node) # strip out CrossSync.rm_aio calls converted = RmAioFunctions().visit(converted) + # strip out CrossSync.is_async branches + converted = StripAsyncConditionalBranches().visit(converted) # replace CrossSync statements converted = SymbolReplacer({"CrossSync": "CrossSync._Sync_Impl"}).visit(converted) return converted @@ -251,14 +283,6 @@ def visit_ClassDef(self, node): continue return self.generic_visit(node) if node else None - def visit_If(self, node): - """ - remove CrossSync.is_async branches from top-level if statements - """ - if isinstance(node.test, ast.Attribute) and isinstance(node.test.value, ast.Name) and node.test.value.id == "CrossSync" and node.test.attr == "is_async": - return [self.generic_visit(n) for n in node.orelse] - return self.generic_visit(node) - def visit_Assign(self, node): """ strip out __CROSS_SYNC_OUTPUT__ assignments diff --git a/tests/system/cross_sync/test_cases/strip_async_conditional_branches.yaml b/tests/system/cross_sync/test_cases/strip_async_conditional_branches.yaml new file mode 100644 index 000000000..0c192fb37 --- /dev/null +++ b/tests/system/cross_sync/test_cases/strip_async_conditional_branches.yaml @@ -0,0 +1,74 @@ +tests: + - description: "top level conditional" + before: | + if CrossSync.is_async: + print("async") + else: + print("sync") + transformers: [StripAsyncConditionalBranches] + after: | + print("sync") + - description: "nested conditional" + before: | + if CrossSync.is_async: + print("async") + else: + print("hello") + if CrossSync.is_async: + 
print("async") + else: + print("world") + transformers: [StripAsyncConditionalBranches] + after: | + print("hello") + print("world") + - description: "conditional within class" + before: | + class MyClass: + def my_method(self): + if CrossSync.is_async: + return "async result" + else: + return "sync result" + transformers: [StripAsyncConditionalBranches] + after: | + class MyClass: + + def my_method(self): + return "sync result" + - description: "multiple branches" + before: | + if CrossSync.is_async: + print("async branch 1") + elif some_condition: + print("other condition") + elif CrossSync.is_async: + print("async branch 2") + else: + print("sync branch") + transformers: [StripAsyncConditionalBranches] + after: | + if some_condition: + print("other condition") + else: + print("sync branch") + - description: "negated conditionals" + before: | + if not CrossSync.is_async: + print("sync code") + else: + print("async code") + + transformers: [StripAsyncConditionalBranches] + after: | + print("sync code") + - description: "is check" + before: | + if CrossSync.is_async is True: + print("async code") + else: + print("sync code") + + transformers: [StripAsyncConditionalBranches] + after: | + print("sync code") diff --git a/tests/system/cross_sync/test_cross_sync_e2e.py b/tests/system/cross_sync/test_cross_sync_e2e.py index ab0e70162..86911b163 100644 --- a/tests/system/cross_sync/test_cross_sync_e2e.py +++ b/tests/system/cross_sync/test_cross_sync_e2e.py @@ -14,6 +14,7 @@ SymbolReplacer, AsyncToSync, RmAioFunctions, + StripAsyncConditionalBranches, CrossSyncFileProcessor, ) From dd2a65f783406b39b5fdb5389e8afcc2ff7486ef Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 13 Sep 2024 15:43:07 -0700 Subject: [PATCH 294/360] fixed visit to generic visit --- .cross_sync/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index 3e3f4de82..a6801f2d1 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -219,7 +219,7 @@ def visit_If(self, node): return [self.visit(n) for n in kept_branch] else: # keep the entire if statement - return self.visit(node) + return self.generic_visit(node) def _is_async_check(self, node) -> bool: """ From 0d487eb1e81856f2b7147816f011564d1ff48ffa Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 15:32:48 -0700 Subject: [PATCH 295/360] import ExecuteQueryIterator --- google/cloud/bigtable/data/__init__.py | 8 -------- google/cloud/bigtable/data/execute_query/__init__.py | 4 ++++ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index e5fa1d1d7..e4130b8a9 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -59,9 +59,6 @@ ) from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync -from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( - ExecuteQueryIteratorAsync, -) from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( PooledBigtableGrpcTransport, @@ -74,9 +71,6 @@ ) from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation -from google.cloud.bigtable.data.execute_query._sync.execute_query_iterator import ( - ExecuteQueryIterator, -) from 
google.cloud.bigtable.data._sync.cross_sync import CrossSync @@ -90,8 +84,6 @@ CrossSync._Sync_Impl.add_mapping("_ReadRowsOperation", _ReadRowsOperation) CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) CrossSync._Sync_Impl.add_mapping("_MutateRowsOperation", _MutateRowsOperation) -CrossSync.add_mapping("ExecuteQueryIterator", ExecuteQueryIteratorAsync) -CrossSync._Sync_Impl.add_mapping("ExecuteQueryIterator", ExecuteQueryIterator) __version__: str = package_version.__version__ diff --git a/google/cloud/bigtable/data/execute_query/__init__.py b/google/cloud/bigtable/data/execute_query/__init__.py index 94af7d1cd..0c33c3f28 100644 --- a/google/cloud/bigtable/data/execute_query/__init__.py +++ b/google/cloud/bigtable/data/execute_query/__init__.py @@ -15,6 +15,9 @@ from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( ExecuteQueryIteratorAsync, ) +from google.cloud.bigtable.data.execute_query._sync.execute_query_iterator import ( + ExecuteQueryIterator, +) from google.cloud.bigtable.data.execute_query.metadata import ( Metadata, ProtoMetadata, @@ -35,4 +38,5 @@ "Metadata", "ProtoMetadata", "ExecuteQueryIteratorAsync", + "ExecuteQueryIterator", ] From a4a591b306065472230ef2c7a2b9d78a2faa8b17 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 15:33:54 -0700 Subject: [PATCH 296/360] added docs for sync files --- docs/async_data_client/async_data_table.rst | 6 --- docs/async_data_client/async_data_usage.rst | 18 --------- .../async_client.rst} | 2 +- .../async_execute_query_iterator.rst} | 0 .../async_mutations_batcher.rst} | 0 docs/data_client/async_table.rst | 6 +++ docs/data_client/client.rst | 6 +++ .../exceptions.rst} | 0 docs/data_client/execute_query_iterator.rst | 6 +++ .../execute_query_metadata.rst} | 0 .../execute_query_values.rst} | 0 .../mutations.rst} | 0 docs/data_client/mutations_batcher.rst | 6 +++ .../read_modify_write_rules.rst} | 0 .../read_rows_query.rst} | 0 .../row.rst} | 0 .../row_filters.rst} | 0 docs/data_client/table.rst | 6 +++ docs/data_client/usage.rst | 39 +++++++++++++++++++ docs/index.rst | 2 +- 20 files changed, 71 insertions(+), 26 deletions(-) delete mode 100644 docs/async_data_client/async_data_table.rst delete mode 100644 docs/async_data_client/async_data_usage.rst rename docs/{async_data_client/async_data_client.rst => data_client/async_client.rst} (52%) rename docs/{async_data_client/async_data_execute_query_iterator.rst => data_client/async_execute_query_iterator.rst} (100%) rename docs/{async_data_client/async_data_mutations_batcher.rst => data_client/async_mutations_batcher.rst} (100%) create mode 100644 docs/data_client/async_table.rst create mode 100644 docs/data_client/client.rst rename docs/{async_data_client/async_data_exceptions.rst => data_client/exceptions.rst} (100%) create mode 100644 docs/data_client/execute_query_iterator.rst rename docs/{async_data_client/async_data_execute_query_metadata.rst => data_client/execute_query_metadata.rst} (100%) rename docs/{async_data_client/async_data_execute_query_values.rst => data_client/execute_query_values.rst} (100%) rename docs/{async_data_client/async_data_mutations.rst => data_client/mutations.rst} (100%) create mode 100644 docs/data_client/mutations_batcher.rst rename docs/{async_data_client/async_data_read_modify_write_rules.rst => data_client/read_modify_write_rules.rst} (100%) rename docs/{async_data_client/async_data_read_rows_query.rst => data_client/read_rows_query.rst} (100%) rename 
docs/{async_data_client/async_data_row.rst => data_client/row.rst} (100%) rename docs/{async_data_client/async_data_row_filters.rst => data_client/row_filters.rst} (100%) create mode 100644 docs/data_client/table.rst create mode 100644 docs/data_client/usage.rst diff --git a/docs/async_data_client/async_data_table.rst b/docs/async_data_client/async_data_table.rst deleted file mode 100644 index a977beb6a..000000000 --- a/docs/async_data_client/async_data_table.rst +++ /dev/null @@ -1,6 +0,0 @@ -Table Async -~~~~~~~~~~~ - -.. autoclass:: google.cloud.bigtable.data._async.client.TableAsync - :members: - :show-inheritance: diff --git a/docs/async_data_client/async_data_usage.rst b/docs/async_data_client/async_data_usage.rst deleted file mode 100644 index 61d5837fd..000000000 --- a/docs/async_data_client/async_data_usage.rst +++ /dev/null @@ -1,18 +0,0 @@ -Async Data Client -================= - -.. toctree:: - :maxdepth: 2 - - async_data_client - async_data_table - async_data_mutations_batcher - async_data_read_rows_query - async_data_row - async_data_row_filters - async_data_mutations - async_data_read_modify_write_rules - async_data_exceptions - async_data_execute_query_iterator - async_data_execute_query_values - async_data_execute_query_metadata diff --git a/docs/async_data_client/async_data_client.rst b/docs/data_client/async_client.rst similarity index 52% rename from docs/async_data_client/async_data_client.rst rename to docs/data_client/async_client.rst index c5cc70740..70c61e676 100644 --- a/docs/async_data_client/async_data_client.rst +++ b/docs/data_client/async_client.rst @@ -1,6 +1,6 @@ Bigtable Data Client Async ~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: google.cloud.bigtable.data._async.client.BigtableDataClientAsync +.. autoclass:: google.cloud.bigtable.data.BigtableDataClientAsync :members: :show-inheritance: diff --git a/docs/async_data_client/async_data_execute_query_iterator.rst b/docs/data_client/async_execute_query_iterator.rst similarity index 100% rename from docs/async_data_client/async_data_execute_query_iterator.rst rename to docs/data_client/async_execute_query_iterator.rst diff --git a/docs/async_data_client/async_data_mutations_batcher.rst b/docs/data_client/async_mutations_batcher.rst similarity index 100% rename from docs/async_data_client/async_data_mutations_batcher.rst rename to docs/data_client/async_mutations_batcher.rst diff --git a/docs/data_client/async_table.rst b/docs/data_client/async_table.rst new file mode 100644 index 000000000..8b069184a --- /dev/null +++ b/docs/data_client/async_table.rst @@ -0,0 +1,6 @@ +Table +~~~~~ + +.. autoclass:: google.cloud.bigtable.data.TableAsync + :members: + :show-inheritance: diff --git a/docs/data_client/client.rst b/docs/data_client/client.rst new file mode 100644 index 000000000..cf7c00dad --- /dev/null +++ b/docs/data_client/client.rst @@ -0,0 +1,6 @@ +Bigtable Data Client +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: google.cloud.bigtable.data.BigtableDataClient + :members: + :show-inheritance: diff --git a/docs/async_data_client/async_data_exceptions.rst b/docs/data_client/exceptions.rst similarity index 100% rename from docs/async_data_client/async_data_exceptions.rst rename to docs/data_client/exceptions.rst diff --git a/docs/data_client/execute_query_iterator.rst b/docs/data_client/execute_query_iterator.rst new file mode 100644 index 000000000..6eb9f84db --- /dev/null +++ b/docs/data_client/execute_query_iterator.rst @@ -0,0 +1,6 @@ +Execute Query Iterator +~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: google.cloud.bigtable.data.execute_query.ExecuteQueryIterator + :members: + :show-inheritance: diff --git a/docs/async_data_client/async_data_execute_query_metadata.rst b/docs/data_client/execute_query_metadata.rst similarity index 100% rename from docs/async_data_client/async_data_execute_query_metadata.rst rename to docs/data_client/execute_query_metadata.rst diff --git a/docs/async_data_client/async_data_execute_query_values.rst b/docs/data_client/execute_query_values.rst similarity index 100% rename from docs/async_data_client/async_data_execute_query_values.rst rename to docs/data_client/execute_query_values.rst diff --git a/docs/async_data_client/async_data_mutations.rst b/docs/data_client/mutations.rst similarity index 100% rename from docs/async_data_client/async_data_mutations.rst rename to docs/data_client/mutations.rst diff --git a/docs/data_client/mutations_batcher.rst b/docs/data_client/mutations_batcher.rst new file mode 100644 index 000000000..b21a193d1 --- /dev/null +++ b/docs/data_client/mutations_batcher.rst @@ -0,0 +1,6 @@ +Mutations Batcher +~~~~~~~~~~~~~~~~~ + +.. automodule:: google.cloud.bigtable.data._sync.mutations_batcher + :members: + :show-inheritance: diff --git a/docs/async_data_client/async_data_read_modify_write_rules.rst b/docs/data_client/read_modify_write_rules.rst similarity index 100% rename from docs/async_data_client/async_data_read_modify_write_rules.rst rename to docs/data_client/read_modify_write_rules.rst diff --git a/docs/async_data_client/async_data_read_rows_query.rst b/docs/data_client/read_rows_query.rst similarity index 100% rename from docs/async_data_client/async_data_read_rows_query.rst rename to docs/data_client/read_rows_query.rst diff --git a/docs/async_data_client/async_data_row.rst b/docs/data_client/row.rst similarity index 100% rename from docs/async_data_client/async_data_row.rst rename to docs/data_client/row.rst diff --git a/docs/async_data_client/async_data_row_filters.rst b/docs/data_client/row_filters.rst similarity index 100% rename from docs/async_data_client/async_data_row_filters.rst rename to docs/data_client/row_filters.rst diff --git a/docs/data_client/table.rst b/docs/data_client/table.rst new file mode 100644 index 000000000..95c91eb27 --- /dev/null +++ b/docs/data_client/table.rst @@ -0,0 +1,6 @@ +Table +~~~~~ + +.. autoclass:: google.cloud.bigtable.data.Table + :members: + :show-inheritance: diff --git a/docs/data_client/usage.rst b/docs/data_client/usage.rst new file mode 100644 index 000000000..e6ebb81ec --- /dev/null +++ b/docs/data_client/usage.rst @@ -0,0 +1,39 @@ +Data Client +=========== + +Async Surface +------------- + +.. toctree:: + :maxdepth: 2 + + async_client + async_table + async_mutations_batcher + async_execute_query_iterator + +Sync Surface +------------ + +.. toctree:: + :maxdepth: 2 + + client + table + mutations_batcher + execute_query_iterator + +Common Classes +-------------- + +.. toctree:: + :maxdepth: 2 + + read_rows_query + row + row_filters + mutations + read_modify_write_rules + exceptions + execute_query_values + execute_query_metadata diff --git a/docs/index.rst b/docs/index.rst index 4204e981d..87f2db9d6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,8 +7,8 @@ Client Types .. 
toctree:: :maxdepth: 2 + data_client/usage classic_client/usage - async_data_client/async_data_usage Changelog From 5736661a7b396969d3b35d56967848ac8e4964a7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 15:43:54 -0700 Subject: [PATCH 297/360] add execute query iterators to cross sync mappings --- google/cloud/bigtable/data/execute_query/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/execute_query/__init__.py b/google/cloud/bigtable/data/execute_query/__init__.py index 0c33c3f28..848fc8481 100644 --- a/google/cloud/bigtable/data/execute_query/__init__.py +++ b/google/cloud/bigtable/data/execute_query/__init__.py @@ -28,7 +28,9 @@ QueryResultRow, Struct, ) - +from google.cloud.bigtable.data._sync.cross_sync import CrossSync +CrossSync.add_mapping("ExecuteQueryIterator", ExecuteQueryIteratorAsync) +CrossSync._Sync_Impl.add_mapping("ExecuteQueryIterator", ExecuteQueryIterator) __all__ = [ "ExecuteQueryValueType", From afe25fe5d3058d57bdf5c5c4fb6369a01a159562 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 15:44:37 -0700 Subject: [PATCH 298/360] regenerated sync --- tests/unit/data/_sync/test_client.py | 12 ++++-------- tests/unit/data/_sync/test_mutations_batcher.py | 6 ++---- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 45b99a9e3..62714ec68 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -703,16 +703,14 @@ def test__multiple_table_registration(self): assert id(table_1) in client._instance_owners[instance_1_key] with client.get_table("instance_1", "table_1") as table_2: assert table_2._register_instance_future is not None - if not CrossSync._Sync_Impl.is_async: - table_2._register_instance_future.result() + table_2._register_instance_future.result() assert len(client._instance_owners[instance_1_key]) == 2 assert len(client._active_instances) == 1 assert id(table_1) in client._instance_owners[instance_1_key] assert id(table_2) in client._instance_owners[instance_1_key] with client.get_table("instance_1", "table_3") as table_3: assert table_3._register_instance_future is not None - if not CrossSync._Sync_Impl.is_async: - table_3._register_instance_future.result() + table_3._register_instance_future.result() instance_3_path = client._gapic_client.instance_path( client.project, "instance_1" ) @@ -740,12 +738,10 @@ def test__multiple_instance_registration(self): with self._make_client(project="project-id") as client: with client.get_table("instance_1", "table_1") as table_1: assert table_1._register_instance_future is not None - if not CrossSync._Sync_Impl.is_async: - table_1._register_instance_future.result() + table_1._register_instance_future.result() with client.get_table("instance_2", "table_2") as table_2: assert table_2._register_instance_future is not None - if not CrossSync._Sync_Impl.is_async: - table_2._register_instance_future.result() + table_2._register_instance_future.result() instance_1_path = client._gapic_client.instance_path( client.project, "instance_1" ) diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync/test_mutations_batcher.py index 7b48b6682..2b7ba8a67 100644 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ b/tests/unit/data/_sync/test_mutations_batcher.py @@ -662,8 +662,7 @@ def test_schedule_flush_with_mutations(self): """if new mutations exist, should add a new flush task to 
_flush_jobs""" with self._make_one() as instance: with mock.patch.object(instance, "_flush_internal") as flush_mock: - if not CrossSync._Sync_Impl.is_async: - flush_mock.side_effect = lambda x: time.sleep(0.1) + flush_mock.side_effect = lambda x: time.sleep(0.1) for i in range(1, 4): mutation = mock.Mock() instance._staged_entries = [mutation] @@ -708,8 +707,7 @@ def test_flush_clears_job_list(self): with mock.patch.object( instance, "_flush_internal", CrossSync._Sync_Impl.Mock() ) as flush_mock: - if not CrossSync._Sync_Impl.is_async: - flush_mock.side_effect = lambda x: time.sleep(0.1) + flush_mock.side_effect = lambda x: time.sleep(0.1) mutations = [self._make_mutation(count=1, size=1)] instance._staged_entries = mutations assert instance._flush_jobs == set() From 59a5df2d37183975723629dc5c8bba01e4dd3f56 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 15:45:20 -0700 Subject: [PATCH 299/360] update transformer value access --- .cross_sync/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index a6801f2d1..d1ba08a58 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -248,7 +248,7 @@ def get_output_path(self, node): for target in n.targets: if isinstance(target, ast.Name) and target.id == self.FILE_ANNOTATION: # return the output path - return n.value.value.replace(".", "/") + ".py" + return n.value.s.replace(".", "/") + ".py" def visit_Module(self, node): # look for __CROSS_SYNC_OUTPUT__ Assign statement From 62d46ec3f8e5d7ea1225b7bb937c80bbb1ff5561 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 16:03:54 -0700 Subject: [PATCH 300/360] fixed py37 test --- tests/unit/data/_async/test_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index f1999f802..4494a6693 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -354,8 +354,8 @@ async def test__ping_and_warm_instances(self): client_mock, channel ) assert len(result) == 0 - assert gather.call_args.kwargs["return_exceptions"] is True - assert gather.call_args.kwargs["sync_executor"] == client_mock._executor + assert gather.call_args[1]["return_exceptions"] is True + assert gather.call_args[1]["sync_executor"] == client_mock._executor # test with instances client_mock._active_instances = [ (mock.Mock(), mock.Mock(), mock.Mock()) @@ -1409,7 +1409,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ except Exception: # we expect an exception from attempting to call the mock pass - kwargs = gapic_mock.call_args_list[0].kwargs + kwargs = gapic_mock.call_args_list[0][1] metadata = kwargs["metadata"] goog_metadata = None for key, value in metadata: @@ -2464,7 +2464,7 @@ async def test_mutate_row_metadata(self, include_app_profile): client._gapic_client, "mutate_row", CrossSync.Mock() ) as read_rows: await table.mutate_row("rk", mock.Mock()) - kwargs = read_rows.call_args_list[0].kwargs + kwargs = read_rows.call_args_list[0][1] metadata = kwargs["metadata"] goog_metadata = None for key, value in metadata: From c3675ce61120f66f163e5e8c26c320b44eb91194 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 16:05:09 -0700 Subject: [PATCH 301/360] regenerate files --- tests/unit/data/_sync/test_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync/test_client.py index 62714ec68..205d10b04 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync/test_client.py @@ -280,8 +280,8 @@ def test__ping_and_warm_instances(self): client_mock, channel ) assert len(result) == 0 - assert gather.call_args.kwargs["return_exceptions"] is True - assert gather.call_args.kwargs["sync_executor"] == client_mock._executor + assert gather.call_args[1]["return_exceptions"] is True + assert gather.call_args[1]["sync_executor"] == client_mock._executor client_mock._active_instances = [ (mock.Mock(), mock.Mock(), mock.Mock()) ] * 4 @@ -1158,7 +1158,7 @@ def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): [i for i in maybe_stream] except Exception: pass - kwargs = gapic_mock.call_args_list[0].kwargs + kwargs = gapic_mock.call_args_list[0][1] metadata = kwargs["metadata"] goog_metadata = None for key, value in metadata: @@ -2088,7 +2088,7 @@ def test_mutate_row_metadata(self, include_app_profile): client._gapic_client, "mutate_row", CrossSync._Sync_Impl.Mock() ) as read_rows: table.mutate_row("rk", mock.Mock()) - kwargs = read_rows.call_args_list[0].kwargs + kwargs = read_rows.call_args_list[0][1] metadata = kwargs["metadata"] goog_metadata = None for key, value in metadata: From 7dc78e12580aef1e797f2ee838130e42d8ca3bb1 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 16:18:04 -0700 Subject: [PATCH 302/360] fixed docfx --- docs/data_client/async_table.rst | 4 ++-- docs/scripts/patch_devsite_toc.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/data_client/async_table.rst b/docs/data_client/async_table.rst index 8b069184a..05ffb8fad 100644 --- a/docs/data_client/async_table.rst +++ b/docs/data_client/async_table.rst @@ -1,5 +1,5 @@ -Table -~~~~~ +Table Async +~~~~~~~~~~~ .. autoclass:: google.cloud.bigtable.data.TableAsync :members: diff --git a/docs/scripts/patch_devsite_toc.py b/docs/scripts/patch_devsite_toc.py index 6338128dd..c2cb14856 100644 --- a/docs/scripts/patch_devsite_toc.py +++ b/docs/scripts/patch_devsite_toc.py @@ -105,7 +105,8 @@ def __init__(self, dir_name, index_file_name): continue # bail when toc indented block is done if not line.startswith(" ") and not line.startswith("\t"): - break + in_toc = False + continue # extract entries self.items.append(self.extract_toc_entry(line.strip())) @@ -179,7 +180,7 @@ def validate_toc(toc_file_path, expected_section_list, added_sections): toc_path = "_build/html/docfx_yaml/toc.yml" custom_sections = [ TocSection( - dir_name="async_data_client", index_file_name="async_data_usage.rst" + dir_name="data_client", index_file_name="usage.rst" ), TocSection(dir_name="classic_client", index_file_name="usage.rst"), ] @@ -194,7 +195,7 @@ def validate_toc(toc_file_path, expected_section_list, added_sections): "bigtable APIs", "Changelog", "Multiprocessing", - "Async Data Client", + "Data Client", "Classic Client", ], added_sections=custom_sections, From 972513140758c09a1f647e378d0f1b6965bc993c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 16:34:42 -0700 Subject: [PATCH 303/360] increase max depth --- docs/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index 87f2db9d6..ee4a89f7e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -5,7 +5,7 @@ Client Types ------------- .. 
toctree:: - :maxdepth: 2 + :maxdepth: 3 data_client/usage classic_client/usage From 5a0fbbab5c111abd5d4228d9d3c22fc454d812d5 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 16:34:49 -0700 Subject: [PATCH 304/360] put sync on top --- docs/data_client/usage.rst | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/data_client/usage.rst b/docs/data_client/usage.rst index e6ebb81ec..8edc424ec 100644 --- a/docs/data_client/usage.rst +++ b/docs/data_client/usage.rst @@ -1,33 +1,33 @@ Data Client =========== -Async Surface -------------- - -.. toctree:: - :maxdepth: 2 - - async_client - async_table - async_mutations_batcher - async_execute_query_iterator - Sync Surface ------------ .. toctree:: - :maxdepth: 2 + :maxdepth: 3 client table mutations_batcher execute_query_iterator +Async Surface +------------- + +.. toctree:: + :maxdepth: 3 + + async_client + async_table + async_mutations_batcher + async_execute_query_iterator + Common Classes -------------- .. toctree:: - :maxdepth: 2 + :maxdepth: 3 read_rows_query row From fd765a7ba56e12d6ee785bee00adc2686cc186ab Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 16:44:42 -0700 Subject: [PATCH 305/360] added notes to README --- .cross_sync/README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.cross_sync/README.md b/.cross_sync/README.md index 5f7a62581..4343b1c47 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -62,10 +62,14 @@ CrossSync provides a set of annotations to mark up async classes, to guide the g ### Code Generation -Generation can be initiated using `python .cross_sync/generate.py .` +Generation can be initiated using `nox -s generate_sync` from the root of the project. This will find all classes with the `__CROSS_SYNC_OUTPUT__ = "path/to/output"` annotation, and generate a sync version of classes marked with `@CrossSync.export_sync` at the output path. +This will also happen automatically as part of the Owlbot CI step. 
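For illustration, a minimal sketch of a module annotated for generation — the module path, class, and method names below are invented for the example; only the `__CROSS_SYNC_OUTPUT__` annotation, the `@CrossSync.convert_class` decorator, and the `CrossSync.is_async` check are taken from the surrounding commits:

    from google.cloud.bigtable.data._sync.cross_sync import CrossSync

    # dotted module path where the generated sync surface is written
    __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync.example"


    @CrossSync.convert_class(sync_name="Example")
    class ExampleAsync:
        async def ping(self) -> str:
            if CrossSync.is_async:
                return "async"  # branch stripped from the generated file
            else:
                return "sync"  # branch kept in the generated file

Running `nox -s generate_sync` (or the Owlbot step mentioned above) would then emit a plain `Example` class at the annotated output path.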
+ +There is a unit test at `tests/unit/data/test_sync_up_to_date.py` that verifies that the generated code is up to date + ## Architecture CrossSync is made up of two parts: From 487f8b20691ce6b306c812e51526324940defc2e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 16:57:41 -0700 Subject: [PATCH 306/360] generate into _sync_autogen directories --- google/cloud/bigtable/data/_async/_mutate_rows.py | 2 +- google/cloud/bigtable/data/_async/_read_rows.py | 2 +- google/cloud/bigtable/data/_async/client.py | 2 +- google/cloud/bigtable/data/_async/mutations_batcher.py | 2 +- .../data/execute_query/_async/execute_query_iterator.py | 2 +- tests/system/data/test_system_async.py | 2 +- tests/unit/data/_async/test__mutate_rows.py | 2 +- tests/unit/data/_async/test__read_rows.py | 2 +- tests/unit/data/_async/test_client.py | 2 +- tests/unit/data/_async/test_mutations_batcher.py | 2 +- tests/unit/data/_async/test_read_rows_acceptance.py | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index ce908ee26..f39e932f6 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -44,7 +44,7 @@ ) from google.cloud.bigtable.data._sync.client import Table as TableType # type: ignore -__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync._mutate_rows" +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._mutate_rows" @CrossSync.convert_class("_MutateRowsOperation") diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index a68660daa..73050f9f5 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -42,7 +42,7 @@ else: from google.cloud.bigtable.data._sync.client import Table as TableType # type: ignore -__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync._read_rows" +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._read_rows" @CrossSync.convert_class("_ReadRowsOperation") diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 72b2792c0..1e283ddab 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -106,7 +106,7 @@ from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery -__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync.client" +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen.client" @CrossSync.convert_class( diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 603f0886a..74fc6175d 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -42,7 +42,7 @@ else: from google.cloud.bigtable.data._sync.client import Table as TableType # type: ignore -__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync.mutations_batcher" +__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen.mutations_batcher" @CrossSync.convert_class(sync_name="_FlowControl", add_mapping_for_name="_FlowControl") diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index 367a2925e..f58c7b081 100644 --- 
a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -51,7 +51,7 @@ from google.cloud.bigtable.data import BigtableDataClient as DataClientType __CROSS_SYNC_OUTPUT__ = ( - "google.cloud.bigtable.data.execute_query._sync.execute_query_iterator" + "google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator" ) diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index ad9c973d2..99667848f 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -27,7 +27,7 @@ from . import TEST_FAMILY, TEST_FAMILY_2 -__CROSS_SYNC_OUTPUT__ = "tests.system.data.test_system" +__CROSS_SYNC_OUTPUT__ = "tests.system.data.test_system_autogen" @CrossSync.convert_class( diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 13a30cb37..ab35ed98c 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -27,7 +27,7 @@ except ImportError: # pragma: NO COVER import mock # type: ignore -__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test__mutate_rows" +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test__mutate_rows" @CrossSync.convert_class("TestMutateRowsOperation") diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 4d63a0fb1..e9a27fca0 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -22,7 +22,7 @@ import mock # type: ignore -__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test__read_rows" +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test__read_rows" @CrossSync.convert_class( diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 4494a6693..35f88fb02 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -47,7 +47,7 @@ CrossSync.add_mapping("grpc_helpers", grpc_helpers) -__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_client" +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_client" @CrossSync.convert_class( diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 2c0c103a1..4fb3559e3 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -23,7 +23,7 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_mutations_batcher" +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_mutations_batcher" @CrossSync.convert_class(sync_name="Test_FlowControl") diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 0f275ca8c..493e62bb9 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -30,7 +30,7 @@ from google.cloud.bigtable.data._sync.cross_sync import CrossSync -__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync.test_read_rows_acceptance" +__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_read_rows_acceptance" @CrossSync.convert_class( From 77560c923b6ff94d5ef1532afbe8a033a8f8a3f8 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 17:00:41 -0700 Subject: [PATCH 307/360] moved cross_sync directory --- .cross_sync/README.md | 2 +- 
.cross_sync/transformers.py | 2 +- .../bigtable/data/{_sync/cross_sync => _cross_sync}/__init__.py | 0 .../data/{_sync/cross_sync => _cross_sync}/_decorators.py | 0 .../data/{_sync/cross_sync => _cross_sync}/_mapping_meta.py | 0 .../data/{_sync/cross_sync => _cross_sync}/cross_sync.py | 0 6 files changed, 2 insertions(+), 2 deletions(-) rename google/cloud/bigtable/data/{_sync/cross_sync => _cross_sync}/__init__.py (100%) rename google/cloud/bigtable/data/{_sync/cross_sync => _cross_sync}/_decorators.py (100%) rename google/cloud/bigtable/data/{_sync/cross_sync => _cross_sync}/_mapping_meta.py (100%) rename google/cloud/bigtable/data/{_sync/cross_sync => _cross_sync}/cross_sync.py (100%) diff --git a/.cross_sync/README.md b/.cross_sync/README.md index 5f7a62581..6bc2a4f69 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -69,5 +69,5 @@ annotation, and generate a sync version of classes marked with `@CrossSync.expor ## Architecture CrossSync is made up of two parts: -- the runtime shims and annotations live in `/google/cloud/bigtable/_sync/_cross_sync` +- the runtime shims and annotations live in `/google/cloud/bigtable/_cross_sync` - the code generation logic lives in `/.cross_sync/` in the repo root diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index d1ba08a58..d40614497 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -31,7 +31,7 @@ import sys # add cross_sync to path -sys.path.append("google/cloud/bigtable/data/_sync/cross_sync") +sys.path.append("google/cloud/bigtable/data/_cross_sync") from _decorators import AstDecorator, ConvertClass diff --git a/google/cloud/bigtable/data/_sync/cross_sync/__init__.py b/google/cloud/bigtable/data/_cross_sync/__init__.py similarity index 100% rename from google/cloud/bigtable/data/_sync/cross_sync/__init__.py rename to google/cloud/bigtable/data/_cross_sync/__init__.py diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_cross_sync/_decorators.py similarity index 100% rename from google/cloud/bigtable/data/_sync/cross_sync/_decorators.py rename to google/cloud/bigtable/data/_cross_sync/_decorators.py diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_mapping_meta.py b/google/cloud/bigtable/data/_cross_sync/_mapping_meta.py similarity index 100% rename from google/cloud/bigtable/data/_sync/cross_sync/_mapping_meta.py rename to google/cloud/bigtable/data/_cross_sync/_mapping_meta.py diff --git a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py b/google/cloud/bigtable/data/_cross_sync/cross_sync.py similarity index 100% rename from google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py rename to google/cloud/bigtable/data/_cross_sync/cross_sync.py From 252ddef037dd7219f6637105166ff892496c3cf8 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 17:01:19 -0700 Subject: [PATCH 308/360] fixed outdated name --- .cross_sync/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.cross_sync/README.md b/.cross_sync/README.md index 6bc2a4f69..4214e0d78 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -64,7 +64,7 @@ CrossSync provides a set of annotations to mark up async classes, to guide the g Generation can be initiated using `python .cross_sync/generate.py .` from the root of the project. 
This will find all classes with the `__CROSS_SYNC_OUTPUT__ = "path/to/output"` -annotation, and generate a sync version of classes marked with `@CrossSync.export_sync` at the output path. +annotation, and generate a sync version of classes marked with `@CrossSync.convert_sync` at the output path. ## Architecture From 3c806b55bae80c816149057bc18dc60b0133bf40 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 17:03:57 -0700 Subject: [PATCH 309/360] fixed import path --- google/cloud/bigtable/data/_async/_mutate_rows.py | 2 +- google/cloud/bigtable/data/_async/_read_rows.py | 2 +- google/cloud/bigtable/data/_async/client.py | 2 +- google/cloud/bigtable/data/_async/mutations_batcher.py | 2 +- tests/system/data/test_system_async.py | 2 +- tests/unit/data/_async/test__mutate_rows.py | 2 +- tests/unit/data/_async/test__read_rows.py | 2 +- tests/unit/data/_async/test_client.py | 4 ++-- tests/unit/data/_async/test_mutations_batcher.py | 4 ++-- tests/unit/data/_async/test_read_rows_acceptance.py | 2 +- 10 files changed, 12 insertions(+), 12 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index f39e932f6..db2d0ac7f 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -28,7 +28,7 @@ from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT from google.cloud.bigtable.data.mutations import _EntryWithProto -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 73050f9f5..8ba34abc6 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -34,7 +34,7 @@ from google.api_core import retry as retries from google.api_core.retry import exponential_sleep_generator -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync if TYPE_CHECKING: if CrossSync.is_async: diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 1e283ddab..da6220b8d 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -79,7 +79,7 @@ from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter from google.cloud.bigtable.data.row_filters import RowFilterChain -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync if CrossSync.is_async: from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc_asyncio import ( diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 74fc6175d..2296b9a78 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -32,7 +32,7 @@ ) from google.cloud.bigtable.data.mutations import Mutation -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry diff --git a/tests/system/data/test_system_async.py 
b/tests/system/data/test_system_async.py index 99667848f..ee05acee7 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -22,7 +22,7 @@ from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE from google.cloud.environment_vars import BIGTABLE_EMULATOR -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync from . import TEST_FAMILY, TEST_FAMILY_2 diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index ab35ed98c..3b3e9591b 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -19,7 +19,7 @@ from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import Forbidden -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync # try/except added for compatibility with python < 3.8 try: diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index e9a27fca0..e460835a1 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -13,7 +13,7 @@ import pytest -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync # try/except added for compatibility with python < 3.8 try: diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 35f88fb02..993e75882 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -34,7 +34,7 @@ from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync if CrossSync.is_async: from google.api_core import grpc_helpers_async @@ -1343,7 +1343,7 @@ async def test_customizable_retryable_errors( else: retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" with mock.patch( - f"google.cloud.bigtable.data._sync.cross_sync.{retry_fn}" + f"google.cloud.bigtable.data._cross_sync.{retry_fn}" ) as retry_fn_mock: async with self._make_client() as client: table = client.get_table("instance-id", "table-id") diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 4fb3559e3..2d0f80ee3 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -21,7 +21,7 @@ from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.cloud.bigtable.data import TABLE_DEFAULT -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_mutations_batcher" @@ -535,7 +535,7 @@ async def test__start_flush_timer_call_when_closed( @pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test__flush_timer(self, num_staged): """Timer should continue to call _schedule_flush in a loop""" - from google.cloud.bigtable.data._sync.cross_sync import CrossSync + from google.cloud.bigtable.data._cross_sync import CrossSync with mock.patch.object( self._get_target_class(), "_schedule_flush" diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py 
b/tests/unit/data/_async/test_read_rows_acceptance.py index 493e62bb9..6ec783069 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -27,7 +27,7 @@ from ...v2_client.test_row_merger import ReadRowsTest, TestFile -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync __CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_read_rows_acceptance" From fff417c442488a90f143cc3cb47f1bf60b28abf0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 17:10:34 -0700 Subject: [PATCH 310/360] fixed import --- .../data/execute_query/_async/execute_query_iterator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index f58c7b081..d58312fa5 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -42,7 +42,7 @@ ExecuteQueryRequest as ExecuteQueryRequestPB, ) -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync if TYPE_CHECKING: if CrossSync.is_async: From e894845bb77415781cfb0a463dc6a3249a655d95 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 17:14:35 -0700 Subject: [PATCH 311/360] moved files to _sync_autogen --- google/cloud/bigtable/data/__init__.py | 12 ++++++------ .../data/{_sync => _sync_autogen}/_mutate_rows.py | 2 +- .../data/{_sync => _sync_autogen}/_read_rows.py | 2 +- .../bigtable/data/{_sync => _sync_autogen}/client.py | 2 +- .../{_sync => _sync_autogen}/mutations_batcher.py | 2 +- google/cloud/bigtable/data/execute_query/__init__.py | 4 ++-- .../execute_query_iterator.py | 2 +- .../data/{test_system.py => test_system_autogen.py} | 2 +- .../data/{_sync => _cross_sync}/test_cross_sync.py | 0 .../test_cross_sync_decorators.py | 0 tests/unit/data/{_sync => _sync_autogen}/__init__.py | 0 .../{_sync => _sync_autogen}/test__mutate_rows.py | 2 +- .../data/{_sync => _sync_autogen}/test__read_rows.py | 2 +- .../data/{_sync => _sync_autogen}/test_client.py | 4 ++-- .../test_mutations_batcher.py | 4 ++-- .../test_read_rows_acceptance.py | 2 +- 16 files changed, 21 insertions(+), 21 deletions(-) rename google/cloud/bigtable/data/{_sync => _sync_autogen}/_mutate_rows.py (99%) rename google/cloud/bigtable/data/{_sync => _sync_autogen}/_read_rows.py (99%) rename google/cloud/bigtable/data/{_sync => _sync_autogen}/client.py (99%) rename google/cloud/bigtable/data/{_sync => _sync_autogen}/mutations_batcher.py (99%) rename google/cloud/bigtable/data/execute_query/{_sync => _sync_autogen}/execute_query_iterator.py (99%) rename tests/system/data/{test_system.py => test_system_autogen.py} (99%) rename tests/unit/data/{_sync => _cross_sync}/test_cross_sync.py (100%) rename tests/unit/data/{_sync => _cross_sync}/test_cross_sync_decorators.py (100%) rename tests/unit/data/{_sync => _sync_autogen}/__init__.py (100%) rename tests/unit/data/{_sync => _sync_autogen}/test__mutate_rows.py (99%) rename tests/unit/data/{_sync => _sync_autogen}/test__read_rows.py (99%) rename tests/unit/data/{_sync => _sync_autogen}/test_client.py (99%) rename tests/unit/data/{_sync => _sync_autogen}/test_mutations_batcher.py (99%) rename tests/unit/data/{_sync => _sync_autogen}/test_read_rows_acceptance.py 
(99%) diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index e4130b8a9..f2a1aaab1 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -18,9 +18,9 @@ from google.cloud.bigtable.data._async.client import BigtableDataClientAsync from google.cloud.bigtable.data._async.client import TableAsync from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync -from google.cloud.bigtable.data._sync.client import BigtableDataClient -from google.cloud.bigtable.data._sync.client import Table -from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher +from google.cloud.bigtable.data._sync_autogen.client import BigtableDataClient +from google.cloud.bigtable.data._sync_autogen.client import Table +from google.cloud.bigtable.data._sync_autogen.mutations_batcher import MutationsBatcher from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.cloud.bigtable.data.read_rows_query import RowRange @@ -69,10 +69,10 @@ from google.cloud.bigtable_v2.services.bigtable.client import ( BigtableClient, ) -from google.cloud.bigtable.data._sync._read_rows import _ReadRowsOperation -from google.cloud.bigtable.data._sync._mutate_rows import _MutateRowsOperation +from google.cloud.bigtable.data._sync_autogen._read_rows import _ReadRowsOperation +from google.cloud.bigtable.data._sync_autogen._mutate_rows import _MutateRowsOperation -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync CrossSync.add_mapping("GapicClient", BigtableAsyncClient) CrossSync._Sync_Impl.add_mapping("GapicClient", BigtableClient) diff --git a/google/cloud/bigtable/data/_sync/_mutate_rows.py b/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py similarity index 99% rename from google/cloud/bigtable/data/_sync/_mutate_rows.py rename to google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py index 232ddced9..55d4a47c5 100644 --- a/google/cloud/bigtable/data/_sync/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py @@ -26,7 +26,7 @@ from google.cloud.bigtable.data._helpers import _retry_exception_factory from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT from google.cloud.bigtable.data.mutations import _EntryWithProto -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry diff --git a/google/cloud/bigtable/data/_sync/_read_rows.py b/google/cloud/bigtable/data/_sync_autogen/_read_rows.py similarity index 99% rename from google/cloud/bigtable/data/_sync/_read_rows.py rename to google/cloud/bigtable/data/_sync_autogen/_read_rows.py index 05254279e..a5e817100 100644 --- a/google/cloud/bigtable/data/_sync/_read_rows.py +++ b/google/cloud/bigtable/data/_sync_autogen/_read_rows.py @@ -32,7 +32,7 @@ from google.cloud.bigtable.data._helpers import _retry_exception_factory from google.api_core import retry as retries from google.api_core.retry import exponential_sleep_generator -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync if TYPE_CHECKING: from google.cloud.bigtable.data._sync.client import Table as TableType diff --git a/google/cloud/bigtable/data/_sync/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py similarity 
index 99% rename from google/cloud/bigtable/data/_sync/client.py rename to google/cloud/bigtable/data/_sync_autogen/client.py index d6c00c91e..257b44f07 100644 --- a/google/cloud/bigtable/data/_sync/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -66,7 +66,7 @@ from google.cloud.bigtable.data.row_filters import StripValueTransformerFilter from google.cloud.bigtable.data.row_filters import CellsRowLimitFilter from google.cloud.bigtable.data.row_filters import RowFilterChain -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync from typing import Iterable from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( PooledBigtableGrpcTransport as PooledTransportType, diff --git a/google/cloud/bigtable/data/_sync/mutations_batcher.py b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py similarity index 99% rename from google/cloud/bigtable/data/_sync/mutations_batcher.py rename to google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py index 800774d36..714f0a946 100644 --- a/google/cloud/bigtable/data/_sync/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py @@ -29,7 +29,7 @@ from google.cloud.bigtable.data._helpers import _MB_SIZE from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT from google.cloud.bigtable.data.mutations import Mutation -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry diff --git a/google/cloud/bigtable/data/execute_query/__init__.py b/google/cloud/bigtable/data/execute_query/__init__.py index 848fc8481..ac49355ae 100644 --- a/google/cloud/bigtable/data/execute_query/__init__.py +++ b/google/cloud/bigtable/data/execute_query/__init__.py @@ -15,7 +15,7 @@ from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( ExecuteQueryIteratorAsync, ) -from google.cloud.bigtable.data.execute_query._sync.execute_query_iterator import ( +from google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator import ( ExecuteQueryIterator, ) from google.cloud.bigtable.data.execute_query.metadata import ( @@ -28,7 +28,7 @@ QueryResultRow, Struct, ) -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync CrossSync.add_mapping("ExecuteQueryIterator", ExecuteQueryIteratorAsync) CrossSync._Sync_Impl.add_mapping("ExecuteQueryIterator", ExecuteQueryIterator) diff --git a/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py similarity index 99% rename from google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py rename to google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py index 974ee3964..aa560b91a 100644 --- a/google/cloud/bigtable/data/execute_query/_sync/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py @@ -33,7 +33,7 @@ from google.cloud.bigtable_v2.types.bigtable import ( ExecuteQueryRequest as ExecuteQueryRequestPB, ) -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync if TYPE_CHECKING: from google.cloud.bigtable.data import 
BigtableDataClient as DataClientType diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system_autogen.py similarity index 99% rename from tests/system/data/test_system.py rename to tests/system/data/test_system_autogen.py index 381256020..40ddc1dcf 100644 --- a/tests/system/data/test_system.py +++ b/tests/system/data/test_system_autogen.py @@ -22,7 +22,7 @@ from google.api_core.exceptions import ClientError from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE from google.cloud.environment_vars import BIGTABLE_EMULATOR -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync from . import TEST_FAMILY, TEST_FAMILY_2 diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_cross_sync/test_cross_sync.py similarity index 100% rename from tests/unit/data/_sync/test_cross_sync.py rename to tests/unit/data/_cross_sync/test_cross_sync.py diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_cross_sync/test_cross_sync_decorators.py similarity index 100% rename from tests/unit/data/_sync/test_cross_sync_decorators.py rename to tests/unit/data/_cross_sync/test_cross_sync_decorators.py diff --git a/tests/unit/data/_sync/__init__.py b/tests/unit/data/_sync_autogen/__init__.py similarity index 100% rename from tests/unit/data/_sync/__init__.py rename to tests/unit/data/_sync_autogen/__init__.py diff --git a/tests/unit/data/_sync/test__mutate_rows.py b/tests/unit/data/_sync_autogen/test__mutate_rows.py similarity index 99% rename from tests/unit/data/_sync/test__mutate_rows.py rename to tests/unit/data/_sync_autogen/test__mutate_rows.py index 59c7074a8..b86bdb943 100644 --- a/tests/unit/data/_sync/test__mutate_rows.py +++ b/tests/unit/data/_sync_autogen/test__mutate_rows.py @@ -20,7 +20,7 @@ from google.rpc import status_pb2 from google.api_core.exceptions import DeadlineExceeded from google.api_core.exceptions import Forbidden -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync try: from unittest import mock diff --git a/tests/unit/data/_sync/test__read_rows.py b/tests/unit/data/_sync_autogen/test__read_rows.py similarity index 99% rename from tests/unit/data/_sync/test__read_rows.py rename to tests/unit/data/_sync_autogen/test__read_rows.py index dc6c24f5b..d762b50e1 100644 --- a/tests/unit/data/_sync/test__read_rows.py +++ b/tests/unit/data/_sync_autogen/test__read_rows.py @@ -16,7 +16,7 @@ # This file is automatically generated by CrossSync. Do not edit manually. 
import pytest -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync try: from unittest import mock diff --git a/tests/unit/data/_sync/test_client.py b/tests/unit/data/_sync_autogen/test_client.py similarity index 99% rename from tests/unit/data/_sync/test_client.py rename to tests/unit/data/_sync_autogen/test_client.py index 205d10b04..419c3a5b0 100644 --- a/tests/unit/data/_sync/test_client.py +++ b/tests/unit/data/_sync_autogen/test_client.py @@ -31,7 +31,7 @@ from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync from google.api_core import grpc_helpers CrossSync._Sync_Impl.add_mapping("grpc_helpers", grpc_helpers) @@ -1098,7 +1098,7 @@ def test_customizable_retryable_errors( retry_fn += "_stream" retry_fn = f"CrossSync._Sync_Impl.{retry_fn}" with mock.patch( - f"google.cloud.bigtable.data._sync.cross_sync.{retry_fn}" + f"google.cloud.bigtable.data._cross_sync.{retry_fn}" ) as retry_fn_mock: with self._make_client() as client: table = client.get_table("instance-id", "table-id") diff --git a/tests/unit/data/_sync/test_mutations_batcher.py b/tests/unit/data/_sync_autogen/test_mutations_batcher.py similarity index 99% rename from tests/unit/data/_sync/test_mutations_batcher.py rename to tests/unit/data/_sync_autogen/test_mutations_batcher.py index 2b7ba8a67..297fd553e 100644 --- a/tests/unit/data/_sync/test_mutations_batcher.py +++ b/tests/unit/data/_sync_autogen/test_mutations_batcher.py @@ -23,7 +23,7 @@ import google.api_core.retry from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.cloud.bigtable.data import TABLE_DEFAULT -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync class Test_FlowControl: @@ -467,7 +467,7 @@ def test__start_flush_timer_call_when_closed(self): @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test__flush_timer(self, num_staged): """Timer should continue to call _schedule_flush in a loop""" - from google.cloud.bigtable.data._sync.cross_sync import CrossSync + from google.cloud.bigtable.data._cross_sync import CrossSync with mock.patch.object( self._get_target_class(), "_schedule_flush" diff --git a/tests/unit/data/_sync/test_read_rows_acceptance.py b/tests/unit/data/_sync_autogen/test_read_rows_acceptance.py similarity index 99% rename from tests/unit/data/_sync/test_read_rows_acceptance.py rename to tests/unit/data/_sync_autogen/test_read_rows_acceptance.py index 61909bf99..3dc3dd00c 100644 --- a/tests/unit/data/_sync/test_read_rows_acceptance.py +++ b/tests/unit/data/_sync_autogen/test_read_rows_acceptance.py @@ -24,7 +24,7 @@ from google.cloud.bigtable.data.exceptions import InvalidChunk from google.cloud.bigtable.data.row import Row from ...v2_client.test_row_merger import ReadRowsTest, TestFile -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync class TestReadRowsAcceptance: From ab6fa6dc49e6da4dde1bbf964719f34f66aee3ab Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 17:26:41 -0700 Subject: [PATCH 312/360] regenerated files --- .../bigtable/data/_sync_autogen/_mutate_rows.py | 
2 +- .../cloud/bigtable/data/_sync_autogen/_read_rows.py | 2 +- google/cloud/bigtable/data/_sync_autogen/client.py | 4 ++-- .../data/_sync_autogen/mutations_batcher.py | 2 +- tests/unit/data/_sync_autogen/test__read_rows.py | 2 +- .../data/_sync_autogen/test_mutations_batcher.py | 2 +- tests/unit/data/execute_query/_sync/__init__.py | 13 ------------- .../data/execute_query/_sync_autogen/__init__.py | 0 .../{_sync => _sync_autogen}/test_query_iterator.py | 2 +- 9 files changed, 8 insertions(+), 21 deletions(-) delete mode 100644 tests/unit/data/execute_query/_sync/__init__.py create mode 100644 tests/unit/data/execute_query/_sync_autogen/__init__.py rename tests/unit/data/execute_query/{_sync => _sync_autogen}/test_query_iterator.py (98%) diff --git a/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py b/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py index 55d4a47c5..7f488db5f 100644 --- a/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py @@ -33,7 +33,7 @@ from google.cloud.bigtable_v2.services.bigtable.client import ( BigtableClient as GapicClientType, ) - from google.cloud.bigtable.data._sync.client import Table as TableType + from google.cloud.bigtable.data._sync_autogen.client import Table as TableType class _MutateRowsOperation: diff --git a/google/cloud/bigtable/data/_sync_autogen/_read_rows.py b/google/cloud/bigtable/data/_sync_autogen/_read_rows.py index a5e817100..271aa3fa2 100644 --- a/google/cloud/bigtable/data/_sync_autogen/_read_rows.py +++ b/google/cloud/bigtable/data/_sync_autogen/_read_rows.py @@ -35,7 +35,7 @@ from google.cloud.bigtable.data._cross_sync import CrossSync if TYPE_CHECKING: - from google.cloud.bigtable.data._sync.client import Table as TableType + from google.cloud.bigtable.data._sync_autogen.client import Table as TableType class _ReadRowsOperation: diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 257b44f07..e9b1e564f 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -71,8 +71,8 @@ from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( PooledBigtableGrpcTransport as PooledTransportType, ) -from google.cloud.bigtable.data._sync.mutations_batcher import MutationsBatcher -from google.cloud.bigtable.data.execute_query._sync.execute_query_iterator import ( +from google.cloud.bigtable.data._sync_autogen.mutations_batcher import MutationsBatcher +from google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator import ( ExecuteQueryIterator, ) diff --git a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py index 714f0a946..2779ffd92 100644 --- a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py @@ -33,7 +33,7 @@ if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry - from google.cloud.bigtable.data._sync.client import Table as TableType + from google.cloud.bigtable.data._sync_autogen.client import Table as TableType @CrossSync._Sync_Impl.add_mapping_decorator("_FlowControl") diff --git a/tests/unit/data/_sync_autogen/test__read_rows.py b/tests/unit/data/_sync_autogen/test__read_rows.py index d762b50e1..671102ce5 100644 --- a/tests/unit/data/_sync_autogen/test__read_rows.py +++ 
b/tests/unit/data/_sync_autogen/test__read_rows.py @@ -53,7 +53,7 @@ def test_ctor(self): expected_operation_timeout = 42 expected_request_timeout = 44 time_gen_mock = mock.Mock() - subpath = "_async" if CrossSync._Sync_Impl.is_async else "_sync" + subpath = "_async" if CrossSync._Sync_Impl.is_async else "_sync_autogen" with mock.patch( f"google.cloud.bigtable.data.{subpath}._read_rows._attempt_timeout_generator", time_gen_mock, diff --git a/tests/unit/data/_sync_autogen/test_mutations_batcher.py b/tests/unit/data/_sync_autogen/test_mutations_batcher.py index 297fd553e..1cb57f3e5 100644 --- a/tests/unit/data/_sync_autogen/test_mutations_batcher.py +++ b/tests/unit/data/_sync_autogen/test_mutations_batcher.py @@ -218,7 +218,7 @@ def test_add_to_flow_max_mutation_limits( ): """Test flow control running up against the max API limit Should submit request early, even if the flow control has room for more""" - subpath = "_async" if CrossSync._Sync_Impl.is_async else "_sync" + subpath = "_async" if CrossSync._Sync_Impl.is_async else "_sync_autogen" path = f"google.cloud.bigtable.data.{subpath}.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT" with mock.patch(path, max_limit): mutation_objs = [ diff --git a/tests/unit/data/execute_query/_sync/__init__.py b/tests/unit/data/execute_query/_sync/__init__.py deleted file mode 100644 index 6d5e14bcf..000000000 --- a/tests/unit/data/execute_query/_sync/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/tests/unit/data/execute_query/_sync_autogen/__init__.py b/tests/unit/data/execute_query/_sync_autogen/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/data/execute_query/_sync/test_query_iterator.py b/tests/unit/data/execute_query/_sync_autogen/test_query_iterator.py similarity index 98% rename from tests/unit/data/execute_query/_sync/test_query_iterator.py rename to tests/unit/data/execute_query/_sync_autogen/test_query_iterator.py index f1142d956..77a28ea92 100644 --- a/tests/unit/data/execute_query/_sync/test_query_iterator.py +++ b/tests/unit/data/execute_query/_sync_autogen/test_query_iterator.py @@ -19,7 +19,7 @@ import concurrent.futures from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync try: from unittest import mock From 6e71ec9fc62d6a401ef05dcf82de2ecc2243e5de Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 17:28:09 -0700 Subject: [PATCH 313/360] fixed imports --- google/cloud/bigtable/data/_async/_mutate_rows.py | 2 +- google/cloud/bigtable/data/_async/_read_rows.py | 2 +- google/cloud/bigtable/data/_async/client.py | 4 ++-- google/cloud/bigtable/data/_async/mutations_batcher.py | 2 +- tests/unit/data/_async/test__read_rows.py | 2 +- tests/unit/data/_async/test_client.py | 2 +- tests/unit/data/_async/test_mutations_batcher.py | 2 +- tests/unit/data/execute_query/_async/test_query_iterator.py | 4 ++-- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index db2d0ac7f..705f03066 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -42,7 +42,7 @@ from google.cloud.bigtable_v2.services.bigtable.client import ( # type: ignore BigtableClient as GapicClientType, ) - from google.cloud.bigtable.data._sync.client import Table as TableType # type: ignore + from google.cloud.bigtable.data._sync_autogen.client import Table as TableType # type: ignore __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._mutate_rows" diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 8ba34abc6..e4ed56d0f 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -40,7 +40,7 @@ if CrossSync.is_async: from google.cloud.bigtable.data._async.client import TableAsync as TableType else: - from google.cloud.bigtable.data._sync.client import Table as TableType # type: ignore + from google.cloud.bigtable.data._sync_autogen.client import Table as TableType # type: ignore __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen._read_rows" diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index da6220b8d..cb305ae60 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -94,10 +94,10 @@ else: from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport as PooledTransportType # type: ignore - from google.cloud.bigtable.data._sync.mutations_batcher import ( # noqa: F401 + from google.cloud.bigtable.data._sync_autogen.mutations_batcher import ( # noqa: F401 
MutationsBatcher, ) - from google.cloud.bigtable.data.execute_query._sync.execute_query_iterator import ( # noqa: F401 + from google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator import ( # noqa: F401 ExecuteQueryIterator, ) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 2296b9a78..1780a902e 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -40,7 +40,7 @@ if CrossSync.is_async: from google.cloud.bigtable.data._async.client import TableAsync as TableType else: - from google.cloud.bigtable.data._sync.client import Table as TableType # type: ignore + from google.cloud.bigtable.data._sync_autogen.client import Table as TableType # type: ignore __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen.mutations_batcher" diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index e460835a1..4c73e80ab 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -58,7 +58,7 @@ def test_ctor(self): expected_operation_timeout = 42 expected_request_timeout = 44 time_gen_mock = mock.Mock() - subpath = "_async" if CrossSync.is_async else "_sync" + subpath = "_async" if CrossSync.is_async else "_sync_autogen" with mock.patch( f"google.cloud.bigtable.data.{subpath}._read_rows._attempt_timeout_generator", time_gen_mock, diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 993e75882..1c91d7992 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -43,7 +43,7 @@ CrossSync.add_mapping("grpc_helpers", grpc_helpers_async) else: from google.api_core import grpc_helpers - from google.cloud.bigtable.data._sync.client import Table # noqa: F401 + from google.cloud.bigtable.data._sync_autogen.client import Table # noqa: F401 CrossSync.add_mapping("grpc_helpers", grpc_helpers) diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index 2d0f80ee3..687c0c41b 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -257,7 +257,7 @@ async def test_add_to_flow_max_mutation_limits( Test flow control running up against the max API limit Should submit request early, even if the flow control has room for more """ - subpath = "_async" if CrossSync.is_async else "_sync" + subpath = "_async" if CrossSync.is_async else "_sync_autogen" path = f"google.cloud.bigtable.data.{subpath}.mutations_batcher._MUTATE_ROWS_REQUEST_MUTATION_LIMIT" with mock.patch(path, max_limit): mutation_objs = [ diff --git a/tests/unit/data/execute_query/_async/test_query_iterator.py b/tests/unit/data/execute_query/_async/test_query_iterator.py index 9cef2dfb3..75b8a2c8e 100644 --- a/tests/unit/data/execute_query/_async/test_query_iterator.py +++ b/tests/unit/data/execute_query/_async/test_query_iterator.py @@ -18,7 +18,7 @@ from google.cloud.bigtable_v2.types.bigtable import ExecuteQueryResponse from .._testing import TYPE_INT, split_bytes_into_chunks, proto_rows_bytes -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync # try/except added for compatibility with python < 3.8 try: @@ -27,7 +27,7 @@ import mock # type: ignore -__CROSS_SYNC_OUTPUT__ = 
"tests.unit.data.execute_query._sync.test_query_iterator" +__CROSS_SYNC_OUTPUT__ = "tests.unit.data.execute_query._sync_autogen.test_query_iterator" @CrossSync.convert_class(sync_name="MockIterator") From 720cf81b33e0d86ae88ae6077a2f656e585ba183 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 17:28:35 -0700 Subject: [PATCH 314/360] fixed import --- tests/unit/data/_sync/test_cross_sync.py | 2 +- tests/unit/data/_sync/test_cross_sync_decorators.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py index 903207694..410f59437 100644 --- a/tests/unit/data/_sync/test_cross_sync.py +++ b/tests/unit/data/_sync/test_cross_sync.py @@ -22,7 +22,7 @@ import functools import sys from google import api_core -from google.cloud.bigtable.data._sync.cross_sync.cross_sync import CrossSync, T +from google.cloud.bigtable.data._cross_sync.cross_sync import CrossSync, T # try/except added for compatibility with python < 3.8 try: diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_sync/test_cross_sync_decorators.py index febb62267..fb35a5834 100644 --- a/tests/unit/data/_sync/test_cross_sync_decorators.py +++ b/tests/unit/data/_sync/test_cross_sync_decorators.py @@ -16,8 +16,8 @@ import pytest_asyncio import ast from unittest import mock -from google.cloud.bigtable.data._sync.cross_sync.cross_sync import CrossSync -from google.cloud.bigtable.data._sync.cross_sync._decorators import ( +from google.cloud.bigtable.data._cross_sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync._decorators import ( ConvertClass, Convert, Drop, From c82d9676beb23a1aeeab69805740fb0e16463d23 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 17:29:01 -0700 Subject: [PATCH 315/360] moved tests --- tests/unit/data/{ => _cross_sync}/_sync/test_cross_sync.py | 0 .../data/{ => _cross_sync}/_sync/test_cross_sync_decorators.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/unit/data/{ => _cross_sync}/_sync/test_cross_sync.py (100%) rename tests/unit/data/{ => _cross_sync}/_sync/test_cross_sync_decorators.py (100%) diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_cross_sync/_sync/test_cross_sync.py similarity index 100% rename from tests/unit/data/_sync/test_cross_sync.py rename to tests/unit/data/_cross_sync/_sync/test_cross_sync.py diff --git a/tests/unit/data/_sync/test_cross_sync_decorators.py b/tests/unit/data/_cross_sync/_sync/test_cross_sync_decorators.py similarity index 100% rename from tests/unit/data/_sync/test_cross_sync_decorators.py rename to tests/unit/data/_cross_sync/_sync/test_cross_sync_decorators.py From 48d44a622dd67914bce1fa67c5d6cec93d3f454d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Sep 2024 17:30:08 -0700 Subject: [PATCH 316/360] renamed test file --- tests/unit/data/_cross_sync/{_sync => }/test_cross_sync.py | 0 .../data/_cross_sync/{_sync => }/test_cross_sync_decorators.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/unit/data/_cross_sync/{_sync => }/test_cross_sync.py (100%) rename tests/unit/data/_cross_sync/{_sync => }/test_cross_sync_decorators.py (100%) diff --git a/tests/unit/data/_cross_sync/_sync/test_cross_sync.py b/tests/unit/data/_cross_sync/test_cross_sync.py similarity index 100% rename from tests/unit/data/_cross_sync/_sync/test_cross_sync.py rename to tests/unit/data/_cross_sync/test_cross_sync.py diff --git 
a/tests/unit/data/_cross_sync/_sync/test_cross_sync_decorators.py b/tests/unit/data/_cross_sync/test_cross_sync_decorators.py similarity index 100% rename from tests/unit/data/_cross_sync/_sync/test_cross_sync_decorators.py rename to tests/unit/data/_cross_sync/test_cross_sync_decorators.py From 7f07436e637fac4ca709549ad1c302cb9d6f8930 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 27 Sep 2024 11:46:47 -0700 Subject: [PATCH 317/360] rm_aio in decorator for async --- google/cloud/bigtable/data/_async/_read_rows.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index e4ed56d0f..d0ca493a5 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -156,7 +156,7 @@ def _read_rows_attempt(self) -> CrossSync.Iterable[Row]: chunked_stream = self.chunk_stream(gapic_stream) return self.merge_rows(chunked_stream) - @CrossSync.convert + @CrossSync.convert(rm_aio=True) async def chunk_stream( self, stream: CrossSync.Awaitable[CrossSync.Iterable[ReadRowsResponsePB]] ) -> CrossSync.Iterable[ReadRowsResponsePB.CellChunk]: @@ -168,7 +168,7 @@ async def chunk_stream( Yields: ReadRowsResponsePB.CellChunk: the next chunk in the stream """ - async for resp in CrossSync.rm_aio(await stream): + async for resp in await stream: # extract proto from proto-plus wrapper resp = resp._pb @@ -210,7 +210,8 @@ async def chunk_stream( @staticmethod @CrossSync.convert( - replace_symbols={"__aiter__": "__iter__", "__anext__": "__next__"} + replace_symbols={"__aiter__": "__iter__", "__anext__": "__next__"}, + rm_aio=True, ) async def merge_rows( chunks: CrossSync.Iterable[ReadRowsResponsePB.CellChunk] | None, @@ -229,7 +230,7 @@ async def merge_rows( # For each row while True: try: - c = CrossSync.rm_aio(await it.__anext__()) + c = await it.__anext__() except CrossSync.StopIteration: # stream complete return @@ -278,7 +279,7 @@ async def merge_rows( buffer = [value] while c.value_size > 0: # throws when premature end - c = CrossSync.rm_aio(await it.__anext__()) + c = await it.__anext__() t = c.timestamp_micros cl = c.labels @@ -310,7 +311,7 @@ async def merge_rows( if c.commit_row: yield Row(row_key, cells) break - c = CrossSync.rm_aio(await it.__anext__()) + c = await it.__anext__() except _ResetRow as e: c = e.chunk if ( From e4b2e9b6d5c24e06d75fd6bd8e51309a9a8be5b2 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 27 Sep 2024 11:47:36 -0700 Subject: [PATCH 318/360] remove owlbot generation --- .cross_sync/README.md | 2 -- owlbot.py | 6 ------ 2 files changed, 8 deletions(-) diff --git a/.cross_sync/README.md b/.cross_sync/README.md index f9ecd37d2..0d8a1cf8c 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -66,8 +66,6 @@ Generation can be initiated using `nox -s generate_sync` from the root of the project. This will find all classes with the `__CROSS_SYNC_OUTPUT__ = "path/to/output"` annotation, and generate a sync version of classes marked with `@CrossSync.convert_sync` at the output path. -This will also happen automatically as part of the Owlbot CI step. 
- There is a unit test at `tests/unit/data/test_sync_up_to_date.py` that verifies that the generated code is up to date ## Architecture diff --git a/owlbot.py b/owlbot.py index c67ae5af8..0ec4cd61c 100644 --- a/owlbot.py +++ b/owlbot.py @@ -171,9 +171,3 @@ def insert(file, before_line, insert_line, after_line, escape=None): INSTALL_LIBRARY_FROM_SOURCE = False""") s.shell.run(["nox", "-s", "blacken"], hide_output=False) - -# ---------------------------------------------------------------------------- -# Run Cross Sync -# ---------------------------------------------------------------------------- - -s.shell.run(["nox", "-s", "generate_sync"]) From c2c28e6ef3974e28286ef90596307e1333ba73de Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 27 Sep 2024 11:51:54 -0700 Subject: [PATCH 319/360] improved error message --- tests/unit/data/test_sync_up_to_date.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/data/test_sync_up_to_date.py b/tests/unit/data/test_sync_up_to_date.py index 22dcead74..0dab4188f 100644 --- a/tests/unit/data/test_sync_up_to_date.py +++ b/tests/unit/data/test_sync_up_to_date.py @@ -48,7 +48,7 @@ def test_sync_up_to_date(sync_file): # compare by content diff = unified_diff(found_render.splitlines(), new_render.splitlines(), lineterm="") diff_str = "\n".join(diff) - assert not diff_str, f"Found differences:\n{diff_str}" + assert not diff_str, f"Found differences. Run `nox -s generate_sync` to update:\n{diff_str}" # compare by hash new_hash = hashlib.md5(new_render.encode()).hexdigest() found_hash = hashlib.md5(found_render.encode()).hexdigest() From dba89e62caf5f3983e83b6b82fd9ac5239c16d86 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 27 Sep 2024 12:57:29 -0700 Subject: [PATCH 320/360] added sanity check to sync_up_to_date --- tests/unit/data/test_sync_up_to_date.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/unit/data/test_sync_up_to_date.py b/tests/unit/data/test_sync_up_to_date.py index 0dab4188f..7f3ef17a1 100644 --- a/tests/unit/data/test_sync_up_to_date.py +++ b/tests/unit/data/test_sync_up_to_date.py @@ -29,6 +29,21 @@ sync_files = list(convert_files_in_dir(repo_root)) +def test_found_files(): + """ + Make sure sync_test is populated with some of the files we expect to see, + to ensure that later tests are actually running. 
+ """ + assert len(sync_files) > 0, "No sync files found" + assert len(sync_files) > 10, "Unexpectedly few sync files found" + # test for key files + outputs = [os.path.basename(f.output_path) for f in sync_files] + assert "client.py" in outputs + assert "execute_query_iterator.py" in outputs + assert "test_client.py" in outputs + assert "test_system_autogen.py" in outputs, "system tests not found" + assert "client_handler_data_sync_autogen.py" in outputs, "test proxy handler not found" + @pytest.mark.skipif( sys.version_info < (3, 9), reason="ast.unparse is only available in 3.9+" From 8cc4e075992fcc6a22f70d857ce94fb9c945f9bc Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 27 Sep 2024 15:10:44 -0700 Subject: [PATCH 321/360] use cross_sync for test proxy --- .github/workflows/conformance.yaml | 4 +- .kokoro/conformance.sh | 3 +- noxfile.py | 11 +-- test_proxy/README.md | 7 +- ...r_data.py => client_handler_data_async.py} | 14 +++- test_proxy/handlers/client_handler_legacy.py | 2 +- test_proxy/noxfile.py | 80 ------------------- test_proxy/run_tests.sh | 3 +- test_proxy/test_proxy.py | 24 +++--- 9 files changed, 37 insertions(+), 111 deletions(-) rename test_proxy/handlers/{client_handler_data.py => client_handler_data_async.py} (94%) delete mode 100644 test_proxy/noxfile.py diff --git a/.github/workflows/conformance.yaml b/.github/workflows/conformance.yaml index 68545cbec..448e1cc3a 100644 --- a/.github/workflows/conformance.yaml +++ b/.github/workflows/conformance.yaml @@ -26,9 +26,9 @@ jobs: matrix: test-version: [ "v0.0.2" ] py-version: [ 3.8 ] - client-type: [ "Async v3", "Legacy" ] + client-type: [ "async", "legacy" ] fail-fast: false - name: "${{ matrix.client-type }} Client / Python ${{ matrix.py-version }} / Test Tag ${{ matrix.test-version }}" + name: "${{ matrix.client-type }} client / python ${{ matrix.py-version }} / test tag ${{ matrix.test-version }}" steps: - uses: actions/checkout@v4 name: "Checkout python-bigtable" diff --git a/.kokoro/conformance.sh b/.kokoro/conformance.sh index 1c0b3ee0d..e85fc1394 100644 --- a/.kokoro/conformance.sh +++ b/.kokoro/conformance.sh @@ -23,7 +23,6 @@ PROXY_ARGS="" TEST_ARGS="" if [[ "${CLIENT_TYPE^^}" == "LEGACY" ]]; then echo "Using legacy client" - PROXY_ARGS="--legacy-client" # legacy client does not expose mutate_row. Disable those tests TEST_ARGS="-skip TestMutateRow_" fi @@ -31,7 +30,7 @@ fi # Build and start the proxy in a separate process PROXY_PORT=9999 pushd test_proxy -nohup python test_proxy.py --port $PROXY_PORT $PROXY_ARGS & +nohup python test_proxy.py --port $PROXY_PORT --client_type=$CLIENT_TYPE & proxyPID=$! 
popd diff --git a/noxfile.py b/noxfile.py index 1931ac3b1..a3964b0f5 100644 --- a/noxfile.py +++ b/noxfile.py @@ -288,9 +288,8 @@ def system_emulated(session): @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) -def conformance(session): - TEST_REPO_URL = "https://github.com/googleapis/cloud-bigtable-clients-test.git" - CLONE_REPO_DIR = "cloud-bigtable-clients-test" +@nox.parametrize("client_type", ["async"]) +def conformance(session, client_type): # install dependencies constraints_path = str( CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" @@ -298,11 +297,7 @@ def conformance(session): install_unittest_dependencies(session, "-c", constraints_path) with session.chdir("test_proxy"): # download the conformance test suite - clone_dir = os.path.join(CURRENT_DIRECTORY, CLONE_REPO_DIR) - if not os.path.exists(clone_dir): - print("downloading copy of test repo") - session.run("git", "clone", TEST_REPO_URL, CLONE_REPO_DIR, external=True) - session.run("bash", "-e", "run_tests.sh", external=True) + session.run("bash", "-e", "run_tests.sh", external=True, env={"CLIENT_TYPE": client_type}) @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) diff --git a/test_proxy/README.md b/test_proxy/README.md index 08741fd5d..266fba7cd 100644 --- a/test_proxy/README.md +++ b/test_proxy/README.md @@ -8,7 +8,7 @@ You can run the conformance tests in a single line by calling `nox -s conformanc ``` -cd python-bigtable/test_proxy +cd python-bigtable nox -s conformance ``` @@ -30,10 +30,11 @@ cd python-bigtable/test_proxy python test_proxy.py --port 8080 ``` -You can run the test proxy against the previous `v2` client by running it with the `--legacy-client` flag: +By default, the test_proxy targets the async client. You can change this by passing in the `--client_type` flag. +Valid options are `async` and `legacy`. ``` -python test_proxy.py --legacy-client +python test_proxy.py --client_type=legacy ``` ### Run the test cases diff --git a/test_proxy/handlers/client_handler_data.py b/test_proxy/handlers/client_handler_data_async.py similarity index 94% rename from test_proxy/handlers/client_handler_data.py rename to test_proxy/handlers/client_handler_data_async.py index 43ff5d634..387909169 100644 --- a/test_proxy/handlers/client_handler_data.py +++ b/test_proxy/handlers/client_handler_data_async.py @@ -18,8 +18,12 @@ from google.cloud.environment_vars import BIGTABLE_EMULATOR from google.cloud.bigtable.data import BigtableDataClientAsync +from google.cloud.bigtable.data._cross_sync import CrossSync +__CROSS_SYNC_OUTPUT__ = "test_proxy.handlers.client_handler_data_sync_autogen" + +@CrossSync.convert(rm_aio=True) def error_safe(func): """ Catch and pass errors back to the grpc_server_process @@ -68,7 +72,8 @@ def encode_exception(exc): return result -class TestProxyClientHandler: +@CrossSync.convert_class("TestProxyClientHandler") +class TestProxyClientHandlerAsync: """ Implements the same methods as the grpc server, but handles the client library side of the request. 
@@ -100,6 +105,7 @@ def close(self): self.closed = True @error_safe + @CrossSync.convert(rm_aio=True) async def ReadRows(self, request, **kwargs): table_id = request.pop("table_name").split("/")[-1] app_profile_id = self.app_profile_id or request.get("app_profile_id", None) @@ -111,6 +117,7 @@ async def ReadRows(self, request, **kwargs): return serialized_response @error_safe + @CrossSync.convert(rm_aio=True) async def ReadRow(self, row_key, **kwargs): table_id = kwargs.pop("table_name").split("/")[-1] app_profile_id = self.app_profile_id or kwargs.get("app_profile_id", None) @@ -124,6 +131,7 @@ async def ReadRow(self, row_key, **kwargs): return "None" @error_safe + @CrossSync.convert(rm_aio=True) async def MutateRow(self, request, **kwargs): from google.cloud.bigtable.data.mutations import Mutation table_id = request["table_name"].split("/")[-1] @@ -136,6 +144,7 @@ async def MutateRow(self, request, **kwargs): return "OK" @error_safe + @CrossSync.convert(rm_aio=True) async def BulkMutateRows(self, request, **kwargs): from google.cloud.bigtable.data.mutations import RowMutationEntry table_id = request["table_name"].split("/")[-1] @@ -147,6 +156,7 @@ async def BulkMutateRows(self, request, **kwargs): return "OK" @error_safe + @CrossSync.convert(rm_aio=True) async def CheckAndMutateRow(self, request, **kwargs): from google.cloud.bigtable.data.mutations import Mutation, SetCell table_id = request["table_name"].split("/")[-1] @@ -181,6 +191,7 @@ async def CheckAndMutateRow(self, request, **kwargs): return result @error_safe + @CrossSync.convert(rm_aio=True) async def ReadModifyWriteRow(self, request, **kwargs): from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule @@ -205,6 +216,7 @@ async def ReadModifyWriteRow(self, request, **kwargs): return "None" @error_safe + @CrossSync.convert(rm_aio=True) async def SampleRowKeys(self, request, **kwargs): table_id = request["table_name"].split("/")[-1] app_profile_id = self.app_profile_id or request.get("app_profile_id", None) diff --git a/test_proxy/handlers/client_handler_legacy.py b/test_proxy/handlers/client_handler_legacy.py index 400f618b5..8a805509b 100644 --- a/test_proxy/handlers/client_handler_legacy.py +++ b/test_proxy/handlers/client_handler_legacy.py @@ -19,7 +19,7 @@ from google.cloud.environment_vars import BIGTABLE_EMULATOR from google.cloud.bigtable.client import Client -import client_handler_data as client_handler +import client_handler_data_async as client_handler import warnings warnings.filterwarnings("ignore", category=DeprecationWarning) diff --git a/test_proxy/noxfile.py b/test_proxy/noxfile.py deleted file mode 100644 index bebf247b7..000000000 --- a/test_proxy/noxfile.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import absolute_import -import os -import pathlib -import re -from colorlog.escape_codes import parse_colors - -import nox - - -DEFAULT_PYTHON_VERSION = "3.10" - -PROXY_SERVER_PORT=os.environ.get("PROXY_SERVER_PORT", "50055") -PROXY_CLIENT_VERSION=os.environ.get("PROXY_CLIENT_VERSION", None) - -CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() -REPO_ROOT_DIRECTORY = CURRENT_DIRECTORY.parent - -nox.options.sessions = ["run_proxy", "conformance_tests"] - -TEST_REPO_URL = "https://github.com/googleapis/cloud-bigtable-clients-test.git" -CLONE_REPO_DIR = "cloud-bigtable-clients-test" - -# Error if a python version is missing -nox.options.error_on_missing_interpreters = True - - -def default(session): - """ - if nox is run directly, run the test_proxy session - """ - test_proxy(session) - - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def conformance_tests(session): - """ - download and run the conformance test suite against the test proxy - """ - import subprocess - import time - # download the conformance test suite - clone_dir = os.path.join(CURRENT_DIRECTORY, CLONE_REPO_DIR) - if not os.path.exists(clone_dir): - print("downloading copy of test repo") - session.run("git", "clone", TEST_REPO_URL, CLONE_REPO_DIR) - # start tests - with session.chdir(f"{clone_dir}/tests"): - session.run("go", "test", "-v", f"-proxy_addr=:{PROXY_SERVER_PORT}") - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def test_proxy(session): - """Start up the test proxy""" - # Install all dependencies, then install this package into the - # virtualenv's dist-packages. - # session.install( - # "grpcio", - # ) - if PROXY_CLIENT_VERSION is not None: - # install released version of the library - session.install(f"python-bigtable=={PROXY_CLIENT_VERSION}") - else: - # install the library from the source - session.install("-e", str(REPO_ROOT_DIRECTORY)) - session.install("-e", str(REPO_ROOT_DIRECTORY / "python-api-core")) - - session.run("python", "test_proxy.py", "--port", PROXY_SERVER_PORT, *session.posargs,) diff --git a/test_proxy/run_tests.sh b/test_proxy/run_tests.sh index 15b146b03..c2e9c6312 100755 --- a/test_proxy/run_tests.sh +++ b/test_proxy/run_tests.sh @@ -35,7 +35,8 @@ if [ ! -d "cloud-bigtable-clients-test" ]; then fi # start proxy -python test_proxy.py --port $PROXY_SERVER_PORT & +echo "starting with client type: $CLIENT_TYPE" +python test_proxy.py --port $PROXY_SERVER_PORT --client_type $CLIENT_TYPE & PROXY_PID=$! 
function finish { kill $PROXY_PID diff --git a/test_proxy/test_proxy.py b/test_proxy/test_proxy.py index a0cf2f1f0..9e03f1e5c 100644 --- a/test_proxy/test_proxy.py +++ b/test_proxy/test_proxy.py @@ -55,7 +55,7 @@ def grpc_server_process(request_q, queue_pool, port=50055): server.wait_for_termination() -async def client_handler_process_async(request_q, queue_pool, use_legacy_client=False): +async def client_handler_process_async(request_q, queue_pool, client_type="async"): """ Defines a process that recives Bigtable requests from a grpc_server_process, and runs the request using a client library instance @@ -64,8 +64,7 @@ async def client_handler_process_async(request_q, queue_pool, use_legacy_client= import re import asyncio import warnings - import client_handler_data - import client_handler_legacy + import client_handler_data_async warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*Bigtable emulator.*") def camel_to_snake(str): @@ -98,9 +97,7 @@ def format_dict(input_obj): return input_obj # Listen to requests from grpc server process - print_msg = "client_handler_process started" - if use_legacy_client: - print_msg += ", using legacy client" + print_msg = f"client_handler_process started with client_type={client_type}" print(print_msg) client_map = {} background_tasks = set() @@ -114,10 +111,11 @@ def format_dict(input_obj): client = client_map.get(client_id, None) # handle special cases for client creation and deletion if fn_name == "CreateClient": - if use_legacy_client: + if client_type == "legacy": + import client_handler_legacy client = client_handler_legacy.LegacyTestProxyClientHandler(**json_data) else: - client = client_handler_data.TestProxyClientHandler(**json_data) + client = client_handler_data_async.TestProxyClientHandlerAsync(**json_data) client_map[client_id] = client out_q.put(True) elif client is None: @@ -142,21 +140,21 @@ async def _run_fn(out_q, fn, **kwargs): await asyncio.sleep(0.01) -def client_handler_process(request_q, queue_pool, legacy_client=False): +def client_handler_process(request_q, queue_pool, client_type="async"): """ Sync entrypoint for client_handler_process_async """ import asyncio - asyncio.run(client_handler_process_async(request_q, queue_pool, legacy_client)) + asyncio.run(client_handler_process_async(request_q, queue_pool, client_type)) p = argparse.ArgumentParser() p.add_argument("--port", dest='port', default="50055") -p.add_argument('--legacy-client', dest='use_legacy', action='store_true', default=False) +p.add_argument("--client_type", dest='client_type', default="async", choices=["async", "legacy"]) if __name__ == "__main__": port = p.parse_args().port - use_legacy_client = p.parse_args().use_legacy + client_type = p.parse_args().client_type # start and run both processes # larger pools support more concurrent requests @@ -176,7 +174,7 @@ def client_handler_process(request_q, queue_pool, legacy_client=False): ), ) proxy.start() - client_handler_process(request_q, response_queue_pool, use_legacy_client) + client_handler_process(request_q, response_queue_pool, client_type) proxy.join() else: # run proxy in forground and client in background From a8f371cb13914570ad67977a5b759a2b107b19fc Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 27 Sep 2024 15:16:01 -0700 Subject: [PATCH 322/360] updated conformance tests for sync client --- .github/workflows/conformance.yaml | 2 +- noxfile.py | 2 +- test_proxy/README.md | 2 +- .../client_handler_data_sync_autogen.py | 227 ++++++++++++++++++ test_proxy/test_proxy.py | 5 +- 5 
files changed, 234 insertions(+), 4 deletions(-) create mode 100644 test_proxy/handlers/client_handler_data_sync_autogen.py diff --git a/.github/workflows/conformance.yaml b/.github/workflows/conformance.yaml index 448e1cc3a..e907d5a92 100644 --- a/.github/workflows/conformance.yaml +++ b/.github/workflows/conformance.yaml @@ -26,7 +26,7 @@ jobs: matrix: test-version: [ "v0.0.2" ] py-version: [ 3.8 ] - client-type: [ "async", "legacy" ] + client-type: [ "async", "sync", "legacy" ] fail-fast: false name: "${{ matrix.client-type }} client / python ${{ matrix.py-version }} / test tag ${{ matrix.test-version }}" steps: diff --git a/noxfile.py b/noxfile.py index d5924de10..d0e71bd9c 100644 --- a/noxfile.py +++ b/noxfile.py @@ -290,7 +290,7 @@ def system_emulated(session): @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) -@nox.parametrize("client_type", ["async"]) +@nox.parametrize("client_type", ["async", "sync"]) def conformance(session, client_type): # install dependencies constraints_path = str( diff --git a/test_proxy/README.md b/test_proxy/README.md index 266fba7cd..5c87c729a 100644 --- a/test_proxy/README.md +++ b/test_proxy/README.md @@ -31,7 +31,7 @@ python test_proxy.py --port 8080 ``` By default, the test_proxy targets the async client. You can change this by passing in the `--client_type` flag. -Valid options are `async` and `legacy`. +Valid options are `async`, `sync`, and `legacy`. ``` python test_proxy.py --client_type=legacy diff --git a/test_proxy/handlers/client_handler_data_sync_autogen.py b/test_proxy/handlers/client_handler_data_sync_autogen.py new file mode 100644 index 000000000..54e17c57d --- /dev/null +++ b/test_proxy/handlers/client_handler_data_sync_autogen.py @@ -0,0 +1,227 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is automatically generated by CrossSync. Do not edit manually. + +""" +This module contains the client handler process for proxy_server.py. 
+""" +import os +from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.cloud.bigtable.data import BigtableDataClientAsync + + +def error_safe(func): + """Catch and pass errors back to the grpc_server_process + Also check if client is closed before processing requests""" + + def wrapper(self, *args, **kwargs): + try: + if self.closed: + raise RuntimeError("client is closed") + return func(self, *args, **kwargs) + except (Exception, NotImplementedError) as e: + return encode_exception(e) + + return wrapper + + +def encode_exception(exc): + """Encode an exception or chain of exceptions to pass back to grpc_handler""" + from google.api_core.exceptions import GoogleAPICallError + + error_msg = f"{type(exc).__name__}: {exc}" + result = {"error": error_msg} + if exc.__cause__: + result["cause"] = encode_exception(exc.__cause__) + if hasattr(exc, "exceptions"): + result["subexceptions"] = [encode_exception(e) for e in exc.exceptions] + if hasattr(exc, "index"): + result["index"] = exc.index + if isinstance(exc, GoogleAPICallError): + if exc.grpc_status_code is not None: + result["code"] = exc.grpc_status_code.value[0] + elif exc.code is not None: + result["code"] = int(exc.code) + else: + result["code"] = -1 + elif result.get("cause", {}).get("code", None): + result["code"] = result["cause"]["code"] + elif result.get("subexceptions", None): + for subexc in result["subexceptions"]: + if subexc.get("code", None): + result["code"] = subexc["code"] + return result + + +class TestProxyClientHandler: + """ + Implements the same methods as the grpc server, but handles the client + library side of the request. + + Requests received in TestProxyGrpcServer are converted to a dictionary, + and supplied to the TestProxyClientHandler methods as kwargs. 
+ The client response is then returned back to the TestProxyGrpcServer + """ + + def __init__( + self, + data_target=None, + project_id=None, + instance_id=None, + app_profile_id=None, + per_operation_timeout=None, + **kwargs, + ): + self.closed = False + os.environ[BIGTABLE_EMULATOR] = data_target + self.client = BigtableDataClientAsync(project=project_id) + self.instance_id = instance_id + self.app_profile_id = app_profile_id + self.per_operation_timeout = per_operation_timeout + + def close(self): + self.closed = True + + @error_safe + def ReadRows(self, request, **kwargs): + table_id = request.pop("table_name").split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + result_list = table.read_rows(request, **kwargs) + serialized_response = [row._to_dict() for row in result_list] + return serialized_response + + @error_safe + def ReadRow(self, row_key, **kwargs): + table_id = kwargs.pop("table_name").split("/")[-1] + app_profile_id = self.app_profile_id or kwargs.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + result_row = table.read_row(row_key, **kwargs) + if result_row: + return result_row._to_dict() + else: + return "None" + + @error_safe + def MutateRow(self, request, **kwargs): + from google.cloud.bigtable.data.mutations import Mutation + + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + row_key = request["row_key"] + mutations = [Mutation._from_dict(d) for d in request["mutations"]] + table.mutate_row(row_key, mutations, **kwargs) + return "OK" + + @error_safe + def BulkMutateRows(self, request, **kwargs): + from google.cloud.bigtable.data.mutations import RowMutationEntry + + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + entry_list = [ + RowMutationEntry._from_dict(entry) for entry in request["entries"] + ] + table.bulk_mutate_rows(entry_list, **kwargs) + return "OK" + + @error_safe + def CheckAndMutateRow(self, request, **kwargs): + from google.cloud.bigtable.data.mutations import Mutation, SetCell + + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + row_key = request["row_key"] + true_mutations = [] + for mut_dict in request.get("true_mutations", []): + try: + true_mutations.append(Mutation._from_dict(mut_dict)) + except ValueError: + mutation = SetCell("", "", "", 0) + true_mutations.append(mutation) + false_mutations = [] + for mut_dict in request.get("false_mutations", []): + try: + 
false_mutations.append(Mutation._from_dict(mut_dict)) + except ValueError: + false_mutations.append(SetCell("", "", "", 0)) + predicate_filter = request.get("predicate_filter", None) + result = table.check_and_mutate_row( + row_key, + predicate_filter, + true_case_mutations=true_mutations, + false_case_mutations=false_mutations, + **kwargs, + ) + return result + + @error_safe + def ReadModifyWriteRow(self, request, **kwargs): + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + row_key = request["row_key"] + rules = [] + for rule_dict in request.get("rules", []): + qualifier = rule_dict["column_qualifier"] + if "append_value" in rule_dict: + new_rule = AppendValueRule( + rule_dict["family_name"], qualifier, rule_dict["append_value"] + ) + else: + new_rule = IncrementRule( + rule_dict["family_name"], qualifier, rule_dict["increment_amount"] + ) + rules.append(new_rule) + result = table.read_modify_write_row(row_key, rules, **kwargs) + if result: + return result._to_dict() + else: + return "None" + + @error_safe + def SampleRowKeys(self, request, **kwargs): + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + result = table.sample_row_keys(**kwargs) + return result diff --git a/test_proxy/test_proxy.py b/test_proxy/test_proxy.py index 9e03f1e5c..793500768 100644 --- a/test_proxy/test_proxy.py +++ b/test_proxy/test_proxy.py @@ -114,6 +114,9 @@ def format_dict(input_obj): if client_type == "legacy": import client_handler_legacy client = client_handler_legacy.LegacyTestProxyClientHandler(**json_data) + elif client_type == "sync": + import client_handler_data_sync_autogen + client = client_handler_data_sync_autogen.TestProxyClientHandler(**json_data) else: client = client_handler_data_async.TestProxyClientHandlerAsync(**json_data) client_map[client_id] = client @@ -150,7 +153,7 @@ def client_handler_process(request_q, queue_pool, client_type="async"): p = argparse.ArgumentParser() p.add_argument("--port", dest='port', default="50055") -p.add_argument("--client_type", dest='client_type', default="async", choices=["async", "legacy"]) +p.add_argument("--client_type", dest='client_type', default="async", choices=["async", "sync", "legacy"]) if __name__ == "__main__": port = p.parse_args().port From 20bb81d249b7a64e70e256a0f1b9ebc2c59de5d9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 27 Sep 2024 15:42:25 -0700 Subject: [PATCH 323/360] fixed test proxy issues --- .../handlers/client_handler_data_async.py | 31 +++++++++---------- test_proxy/handlers/client_handler_legacy.py | 2 +- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/test_proxy/handlers/client_handler_data_async.py b/test_proxy/handlers/client_handler_data_async.py index 387909169..7f6cc413f 100644 --- a/test_proxy/handlers/client_handler_data_async.py +++ b/test_proxy/handlers/client_handler_data_async.py @@ -20,10 +20,13 @@ from 
google.cloud.bigtable.data import BigtableDataClientAsync from google.cloud.bigtable.data._cross_sync import CrossSync +if not CrossSync.is_async: + from client_handler_data_async import error_safe + __CROSS_SYNC_OUTPUT__ = "test_proxy.handlers.client_handler_data_sync_autogen" -@CrossSync.convert(rm_aio=True) +@CrossSync.drop def error_safe(func): """ Catch and pass errors back to the grpc_server_process @@ -41,6 +44,7 @@ async def wrapper(self, *args, **kwargs): return wrapper +@CrossSync.drop def encode_exception(exc): """ Encode an exception or chain of exceptions to pass back to grpc_handler @@ -95,7 +99,7 @@ def __init__( self.closed = False # use emulator os.environ[BIGTABLE_EMULATOR] = data_target - self.client = BigtableDataClientAsync(project=project_id) + self.client = CrossSync.DataClient(project=project_id) self.instance_id = instance_id self.app_profile_id = app_profile_id self.per_operation_timeout = per_operation_timeout @@ -105,25 +109,23 @@ def close(self): self.closed = True @error_safe - @CrossSync.convert(rm_aio=True) async def ReadRows(self, request, **kwargs): table_id = request.pop("table_name").split("/")[-1] app_profile_id = self.app_profile_id or request.get("app_profile_id", None) table = self.client.get_table(self.instance_id, table_id, app_profile_id) kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 - result_list = await table.read_rows(request, **kwargs) + result_list = CrossSync.rm_aio(await table.read_rows(request, **kwargs)) # pack results back into protobuf-parsable format serialized_response = [row._to_dict() for row in result_list] return serialized_response @error_safe - @CrossSync.convert(rm_aio=True) async def ReadRow(self, row_key, **kwargs): table_id = kwargs.pop("table_name").split("/")[-1] app_profile_id = self.app_profile_id or kwargs.get("app_profile_id", None) table = self.client.get_table(self.instance_id, table_id, app_profile_id) kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 - result_row = await table.read_row(row_key, **kwargs) + result_row = CrossSync.rm_aio(await table.read_row(row_key, **kwargs)) # pack results back into protobuf-parsable format if result_row: return result_row._to_dict() @@ -131,7 +133,6 @@ async def ReadRow(self, row_key, **kwargs): return "None" @error_safe - @CrossSync.convert(rm_aio=True) async def MutateRow(self, request, **kwargs): from google.cloud.bigtable.data.mutations import Mutation table_id = request["table_name"].split("/")[-1] @@ -140,11 +141,10 @@ async def MutateRow(self, request, **kwargs): kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 row_key = request["row_key"] mutations = [Mutation._from_dict(d) for d in request["mutations"]] - await table.mutate_row(row_key, mutations, **kwargs) + CrossSync.rm_aio(await table.mutate_row(row_key, mutations, **kwargs)) return "OK" @error_safe - @CrossSync.convert(rm_aio=True) async def BulkMutateRows(self, request, **kwargs): from google.cloud.bigtable.data.mutations import RowMutationEntry table_id = request["table_name"].split("/")[-1] @@ -152,11 +152,10 @@ async def BulkMutateRows(self, request, **kwargs): table = self.client.get_table(self.instance_id, table_id, app_profile_id) kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 entry_list = [RowMutationEntry._from_dict(entry) for entry in request["entries"]] - await table.bulk_mutate_rows(entry_list, **kwargs) + 
CrossSync.rm_aio(await table.bulk_mutate_rows(entry_list, **kwargs)) return "OK" @error_safe - @CrossSync.convert(rm_aio=True) async def CheckAndMutateRow(self, request, **kwargs): from google.cloud.bigtable.data.mutations import Mutation, SetCell table_id = request["table_name"].split("/")[-1] @@ -181,17 +180,16 @@ async def CheckAndMutateRow(self, request, **kwargs): # invalid mutation type. Conformance test may be sending generic empty request false_mutations.append(SetCell("", "", "", 0)) predicate_filter = request.get("predicate_filter", None) - result = await table.check_and_mutate_row( + result = CrossSync.rm_aio(await table.check_and_mutate_row( row_key, predicate_filter, true_case_mutations=true_mutations, false_case_mutations=false_mutations, **kwargs, - ) + )) return result @error_safe - @CrossSync.convert(rm_aio=True) async def ReadModifyWriteRow(self, request, **kwargs): from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule @@ -208,7 +206,7 @@ async def ReadModifyWriteRow(self, request, **kwargs): else: new_rule = IncrementRule(rule_dict["family_name"], qualifier, rule_dict["increment_amount"]) rules.append(new_rule) - result = await table.read_modify_write_row(row_key, rules, **kwargs) + result = CrossSync.rm_aio(await table.read_modify_write_row(row_key, rules, **kwargs)) # pack results back into protobuf-parsable format if result: return result._to_dict() @@ -216,11 +214,10 @@ async def ReadModifyWriteRow(self, request, **kwargs): return "None" @error_safe - @CrossSync.convert(rm_aio=True) async def SampleRowKeys(self, request, **kwargs): table_id = request["table_name"].split("/")[-1] app_profile_id = self.app_profile_id or request.get("app_profile_id", None) table = self.client.get_table(self.instance_id, table_id, app_profile_id) kwargs["operation_timeout"] = kwargs.get("operation_timeout", self.per_operation_timeout) or 20 - result = await table.sample_row_keys(**kwargs) + result = CrossSync.rm_aio(await table.sample_row_keys(**kwargs)) return result diff --git a/test_proxy/handlers/client_handler_legacy.py b/test_proxy/handlers/client_handler_legacy.py index 8a805509b..63fe357b0 100644 --- a/test_proxy/handlers/client_handler_legacy.py +++ b/test_proxy/handlers/client_handler_legacy.py @@ -25,7 +25,7 @@ warnings.filterwarnings("ignore", category=DeprecationWarning) -class LegacyTestProxyClientHandler(client_handler.TestProxyClientHandler): +class LegacyTestProxyClientHandler(client_handler.TestProxyClientHandlerAsync): def __init__( self, From 61e3339f095922cccd884ba2fd53cb642fddc738 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 27 Sep 2024 15:43:21 -0700 Subject: [PATCH 324/360] updated test proxy files --- .../client_handler_data_sync_autogen.py | 66 ++++--------------- 1 file changed, 12 insertions(+), 54 deletions(-) diff --git a/test_proxy/handlers/client_handler_data_sync_autogen.py b/test_proxy/handlers/client_handler_data_sync_autogen.py index 54e17c57d..52ddec6fd 100644 --- a/test_proxy/handlers/client_handler_data_sync_autogen.py +++ b/test_proxy/handlers/client_handler_data_sync_autogen.py @@ -19,50 +19,8 @@ """ import os from google.cloud.environment_vars import BIGTABLE_EMULATOR -from google.cloud.bigtable.data import BigtableDataClientAsync - - -def error_safe(func): - """Catch and pass errors back to the grpc_server_process - Also check if client is closed before processing requests""" - - def wrapper(self, *args, **kwargs): - try: - if 
self.closed: - raise RuntimeError("client is closed") - return func(self, *args, **kwargs) - except (Exception, NotImplementedError) as e: - return encode_exception(e) - - return wrapper - - -def encode_exception(exc): - """Encode an exception or chain of exceptions to pass back to grpc_handler""" - from google.api_core.exceptions import GoogleAPICallError - - error_msg = f"{type(exc).__name__}: {exc}" - result = {"error": error_msg} - if exc.__cause__: - result["cause"] = encode_exception(exc.__cause__) - if hasattr(exc, "exceptions"): - result["subexceptions"] = [encode_exception(e) for e in exc.exceptions] - if hasattr(exc, "index"): - result["index"] = exc.index - if isinstance(exc, GoogleAPICallError): - if exc.grpc_status_code is not None: - result["code"] = exc.grpc_status_code.value[0] - elif exc.code is not None: - result["code"] = int(exc.code) - else: - result["code"] = -1 - elif result.get("cause", {}).get("code", None): - result["code"] = result["cause"]["code"] - elif result.get("subexceptions", None): - for subexc in result["subexceptions"]: - if subexc.get("code", None): - result["code"] = subexc["code"] - return result +from google.cloud.bigtable.data._cross_sync import CrossSync +from client_handler_data_async import error_safe class TestProxyClientHandler: @@ -82,11 +40,11 @@ def __init__( instance_id=None, app_profile_id=None, per_operation_timeout=None, - **kwargs, + **kwargs ): self.closed = False os.environ[BIGTABLE_EMULATOR] = data_target - self.client = BigtableDataClientAsync(project=project_id) + self.client = CrossSync._Sync_Impl.DataClient(project=project_id) self.instance_id = instance_id self.app_profile_id = app_profile_id self.per_operation_timeout = per_operation_timeout @@ -95,7 +53,7 @@ def close(self): self.closed = True @error_safe - def ReadRows(self, request, **kwargs): + async def ReadRows(self, request, **kwargs): table_id = request.pop("table_name").split("/")[-1] app_profile_id = self.app_profile_id or request.get("app_profile_id", None) table = self.client.get_table(self.instance_id, table_id, app_profile_id) @@ -107,7 +65,7 @@ def ReadRows(self, request, **kwargs): return serialized_response @error_safe - def ReadRow(self, row_key, **kwargs): + async def ReadRow(self, row_key, **kwargs): table_id = kwargs.pop("table_name").split("/")[-1] app_profile_id = self.app_profile_id or kwargs.get("app_profile_id", None) table = self.client.get_table(self.instance_id, table_id, app_profile_id) @@ -121,7 +79,7 @@ def ReadRow(self, row_key, **kwargs): return "None" @error_safe - def MutateRow(self, request, **kwargs): + async def MutateRow(self, request, **kwargs): from google.cloud.bigtable.data.mutations import Mutation table_id = request["table_name"].split("/")[-1] @@ -136,7 +94,7 @@ def MutateRow(self, request, **kwargs): return "OK" @error_safe - def BulkMutateRows(self, request, **kwargs): + async def BulkMutateRows(self, request, **kwargs): from google.cloud.bigtable.data.mutations import RowMutationEntry table_id = request["table_name"].split("/")[-1] @@ -152,7 +110,7 @@ def BulkMutateRows(self, request, **kwargs): return "OK" @error_safe - def CheckAndMutateRow(self, request, **kwargs): + async def CheckAndMutateRow(self, request, **kwargs): from google.cloud.bigtable.data.mutations import Mutation, SetCell table_id = request["table_name"].split("/")[-1] @@ -181,12 +139,12 @@ def CheckAndMutateRow(self, request, **kwargs): predicate_filter, true_case_mutations=true_mutations, false_case_mutations=false_mutations, - **kwargs, + **kwargs ) return 
result @error_safe - def ReadModifyWriteRow(self, request, **kwargs): + async def ReadModifyWriteRow(self, request, **kwargs): from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule @@ -216,7 +174,7 @@ def ReadModifyWriteRow(self, request, **kwargs): return "None" @error_safe - def SampleRowKeys(self, request, **kwargs): + async def SampleRowKeys(self, request, **kwargs): table_id = request["table_name"].split("/")[-1] app_profile_id = self.app_profile_id or request.get("app_profile_id", None) table = self.client.get_table(self.instance_id, table_id, app_profile_id) From 601edf7036a113d077098135e4e078ceda1149b3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 27 Sep 2024 16:31:17 -0700 Subject: [PATCH 325/360] skip multistream conformance tests in sync surface --- .kokoro/conformance.sh | 8 +++++++- noxfile.py | 2 +- test_proxy/run_tests.sh | 17 +++++++++++++++-- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/.kokoro/conformance.sh b/.kokoro/conformance.sh index e85fc1394..f2650c67f 100644 --- a/.kokoro/conformance.sh +++ b/.kokoro/conformance.sh @@ -21,12 +21,18 @@ cd $(dirname $0)/.. PROXY_ARGS="" TEST_ARGS="" -if [[ "${CLIENT_TYPE^^}" == "LEGACY" ]]; then +if [[ $CLIENT_TYPE == "legacy" ]]; then echo "Using legacy client" # legacy client does not expose mutate_row. Disable those tests TEST_ARGS="-skip TestMutateRow_" fi +if [[ $CLIENT_TYPE != "async" ]]; then + echo "Using legacy client" + # sync and legacy client do not support concurrent streams + TEST_ARGS="$TEST_ARGS -skip _Generic_MultiStream " +fi + # Build and start the proxy in a separate process PROXY_PORT=9999 pushd test_proxy diff --git a/noxfile.py b/noxfile.py index d0e71bd9c..9b874eb6c 100644 --- a/noxfile.py +++ b/noxfile.py @@ -290,7 +290,7 @@ def system_emulated(session): @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) -@nox.parametrize("client_type", ["async", "sync"]) +@nox.parametrize("client_type", ["async", "sync", "legacy"]) def conformance(session, client_type): # install dependencies constraints_path = str( diff --git a/test_proxy/run_tests.sh b/test_proxy/run_tests.sh index c2e9c6312..b6f1291a6 100755 --- a/test_proxy/run_tests.sh +++ b/test_proxy/run_tests.sh @@ -27,7 +27,7 @@ fi SCRIPT_DIR=$(realpath $(dirname "$0")) cd $SCRIPT_DIR -export PROXY_SERVER_PORT=50055 +export PROXY_SERVER_PORT=$(shuf -i 50000-60000 -n 1) # download test suite if [ ! -d "cloud-bigtable-clients-test" ]; then @@ -43,6 +43,19 @@ function finish { } trap finish EXIT +if [[ $CLIENT_TYPE == "legacy" ]]; then + echo "Using legacy client" + # legacy client does not expose mutate_row. 
Disable those tests + TEST_ARGS="-skip TestMutateRow_" +fi + +if [[ $CLIENT_TYPE != "async" ]]; then + echo "Using legacy client" + # sync and legacy client do not support concurrent streams + TEST_ARGS="$TEST_ARGS -skip _Generic_MultiStream " +fi + # run tests pushd cloud-bigtable-clients-test/tests -go test -v -proxy_addr=:$PROXY_SERVER_PORT +echo "Running with $TEST_ARGS" +go test -v -proxy_addr=:$PROXY_SERVER_PORT $TEST_ARGS From 86651db5874054655bfadcc53979fc113fbd782b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 27 Sep 2024 16:47:30 -0700 Subject: [PATCH 326/360] add skips into github actions file --- .github/workflows/conformance.yaml | 9 +++++++++ .kokoro/conformance.sh | 16 +--------------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/.github/workflows/conformance.yaml b/.github/workflows/conformance.yaml index e907d5a92..e39c7b5a2 100644 --- a/.github/workflows/conformance.yaml +++ b/.github/workflows/conformance.yaml @@ -27,6 +27,13 @@ jobs: test-version: [ "v0.0.2" ] py-version: [ 3.8 ] client-type: [ "async", "sync", "legacy" ] + include: + - client-type: "sync" + # sync and legacy client do not support concurrent streams + test_args: "-skip _Generic_MultiStream" + - client-type: "legacy" + # legacy client does not expose mutate_row. Disable those tests + test_argss: "-skip _Generic_MultiStream -skip TestMutateRow_" fail-fast: false name: "${{ matrix.client-type }} client / python ${{ matrix.py-version }} / test tag ${{ matrix.test-version }}" steps: @@ -53,4 +60,6 @@ jobs: env: CLIENT_TYPE: ${{ matrix.client-type }} PYTHONUNBUFFERED: 1 + TEST_ARGS: ${{ matrix.test_args }} + PROXY_PORT: 9999 diff --git a/.kokoro/conformance.sh b/.kokoro/conformance.sh index f2650c67f..fd585142e 100644 --- a/.kokoro/conformance.sh +++ b/.kokoro/conformance.sh @@ -19,22 +19,7 @@ set -eo pipefail ## cd to the parent directory, i.e. the root of the git repo cd $(dirname $0)/.. -PROXY_ARGS="" -TEST_ARGS="" -if [[ $CLIENT_TYPE == "legacy" ]]; then - echo "Using legacy client" - # legacy client does not expose mutate_row. Disable those tests - TEST_ARGS="-skip TestMutateRow_" -fi - -if [[ $CLIENT_TYPE != "async" ]]; then - echo "Using legacy client" - # sync and legacy client do not support concurrent streams - TEST_ARGS="$TEST_ARGS -skip _Generic_MultiStream " -fi - # Build and start the proxy in a separate process -PROXY_PORT=9999 pushd test_proxy nohup python test_proxy.py --port $PROXY_PORT --client_type=$CLIENT_TYPE & proxyPID=$! @@ -48,6 +33,7 @@ function cleanup() { trap cleanup EXIT # Run the conformance test +echo "running tests with args: $TEST_ARGS" pushd cloud-bigtable-clients-test/tests eval "go test -v -proxy_addr=:$PROXY_PORT $TEST_ARGS" RETURN_CODE=$? 
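
The commits above lean on the `CrossSync.rm_aio(...)` marker: awaited calls in the hand-written async handlers are wrapped so the generator can strip asyncio keywords when it emits `client_handler_data_sync_autogen.py`. Below is a minimal, standalone sketch of that stripping step using only the stdlib `ast` module. It is illustrative only; the project's real transformer (`RmAioFunctions` in `.cross_sync/transformers.py`) also rewrites the enclosing `async def`, `async for`, and `async with` constructs.

```python
import ast


def is_rm_aio_call(node: ast.AST) -> bool:
    """Return True for ``CrossSync.rm_aio(...)`` call nodes."""
    return (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Attribute)
        and isinstance(node.func.value, ast.Name)
        and node.func.value.id == "CrossSync"
        and node.func.attr == "rm_aio"
    )


class UnwrapRmAio(ast.NodeTransformer):
    """Replace ``CrossSync.rm_aio(await expr)`` with plain ``expr``."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)
        if is_rm_aio_call(node):
            inner = node.args[0]
            # drop a directly wrapped ``await``; otherwise keep the inner expression
            return inner.value if isinstance(inner, ast.Await) else inner
        return node


src = '''
async def ReadRows(self, request, **kwargs):
    result_list = CrossSync.rm_aio(await table.read_rows(request, **kwargs))
    return [row._to_dict() for row in result_list]
'''
print(ast.unparse(UnwrapRmAio().visit(ast.parse(src))))
# async def ReadRows(self, request, **kwargs):
#     result_list = table.read_rows(request, **kwargs)
#     return [row._to_dict() for row in result_list]
```

Removing the surrounding `async`/`await` keywords on the function itself is handled separately (by `AsyncToSync`), which is why the sketch's output still shows `async def`.
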
From 2ab2e40827fe5f9a0231db7c266f81a39e30fdce Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 27 Sep 2024 16:49:46 -0700 Subject: [PATCH 327/360] fixed typo --- .github/workflows/conformance.yaml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/conformance.yaml b/.github/workflows/conformance.yaml index e39c7b5a2..d4e992c8d 100644 --- a/.github/workflows/conformance.yaml +++ b/.github/workflows/conformance.yaml @@ -29,11 +29,12 @@ jobs: client-type: [ "async", "sync", "legacy" ] include: - client-type: "sync" - # sync and legacy client do not support concurrent streams + # sync client does not support concurrent streams test_args: "-skip _Generic_MultiStream" - client-type: "legacy" + # legacy client is synchtonous and does not support concurrent streams # legacy client does not expose mutate_row. Disable those tests - test_argss: "-skip _Generic_MultiStream -skip TestMutateRow_" + test_args: "-skip _Generic_MultiStream -skip TestMutateRow_" fail-fast: false name: "${{ matrix.client-type }} client / python ${{ matrix.py-version }} / test tag ${{ matrix.test-version }}" steps: From 14a5f2568c8862fab9991d092f492b80c112f770 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 24 Oct 2024 15:56:02 -0700 Subject: [PATCH 328/360] renamed with_formatter --- .cross_sync/generate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index 86d515765..b0d855b96 100644 --- a/.cross_sync/generate.py +++ b/.cross_sync/generate.py @@ -46,16 +46,16 @@ def __init__(self, output_path: str, ast_tree, header: str | None = None): self.tree = ast_tree self.header = header or "" - def render(self, with_black=True, save_to_disk: bool = False) -> str: + def render(self, with_formatter=True, save_to_disk: bool = False) -> str: """ Render the file to a string, and optionally save to disk Args: - with_black: whether to run the output through black before returning + with_formatter: whether to run the output through black before returning save_to_disk: whether to write the output to the file path """ full_str = self.header + ast.unparse(self.tree) - if with_black: + if with_formatter: import black # type: ignore import autoflake # type: ignore From 8830375d933ec86cb60b23fb7739a69996d95781 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 24 Oct 2024 16:08:58 -0700 Subject: [PATCH 329/360] fixed docstrings --- .cross_sync/transformers.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index d40614497..d502c22e9 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -16,23 +16,23 @@ async code into sync code. At a high level: -- The main entrypoint is CrossSyncClassDecoratorHandler, which is used to find classes -annotated with @CrossSync.export_sync. +- The main entrypoint is CrossSyncFileProcessor, which is used to find files in + the codebase that include __CROSS_SYNC_OUTPUT__, and transform them + according to the `CrossSync` annotations they contains - SymbolReplacer is used to swap out CrossSync.X with CrossSync._Sync_Impl.X -- RmAioFunctions is then called on the class, to strip out asyncio keywords -marked with CrossSync.rm_aio (using AsyncToSync to handle the actual transformation) -- Finally, CrossSyncMethodDecoratorHandler is called to find methods annotated -with AstDecorators, and call decorator.sync_ast_transform on each one to fully transform the class. 
+- RmAioFunctions is used to strip out asyncio keywords marked with CrossSync.rm_aio + (deferring to AsyncToSync to handle the actual transformation) +- StripAsyncConditionalBranches finds `if CrossSync.is_async:` conditionals, and strips out + the unneeded branch for the sync output """ from __future__ import annotations import ast -import copy import sys # add cross_sync to path sys.path.append("google/cloud/bigtable/data/_cross_sync") -from _decorators import AstDecorator, ConvertClass +from _decorators import AstDecorator class SymbolReplacer(ast.NodeTransformer): @@ -144,6 +144,7 @@ def visit_ListComp(self, node): generator.is_async = False return self.generic_visit(node) + class RmAioFunctions(ast.NodeTransformer): """ Visits all calls marked with CrossSync.rm_aio, and removes asyncio keywords @@ -235,15 +236,14 @@ class CrossSyncFileProcessor(ast.NodeTransformer): If found, the file is processed with the following steps: - Strip out asyncio keywords within CrossSync.rm_aio calls - transform classes and methods annotated with CrossSync decorators - - classes not marked with @CrossSync.export are discarded in sync version - statements behind CrossSync.is_async conditional branches are removed - - Replace remaining CrossSync statements with corresponding CrossSync._Sync calls + - Replace remaining CrossSync statements with corresponding CrossSync._Sync_Impl calls - save changes in an output file at path specified by __CROSS_SYNC_OUTPUT__ """ FILE_ANNOTATION = "__CROSS_SYNC_OUTPUT__" def get_output_path(self, node): - for i, n in enumerate(node.body): + for n in node.body: if isinstance(n, ast.Assign): for target in n.targets: if isinstance(target, ast.Name) and target.id == self.FILE_ANNOTATION: From 6be8180f12011f44907411be30ebdaef07c4a28d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 24 Oct 2024 16:26:12 -0700 Subject: [PATCH 330/360] cleaning up rm_aio --- .cross_sync/transformers.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/.cross_sync/transformers.py b/.cross_sync/transformers.py index d502c22e9..80e384361 100644 --- a/.cross_sync/transformers.py +++ b/.cross_sync/transformers.py @@ -149,31 +149,38 @@ class RmAioFunctions(ast.NodeTransformer): """ Visits all calls marked with CrossSync.rm_aio, and removes asyncio keywords """ + RM_AIO_FN_NAME = "rm_aio" + RM_AIO_CLASS_NAME = "CrossSync" def __init__(self): self.to_sync = AsyncToSync() + def _is_rm_aio_call(self, node) -> bool: + """ + Check if a node is a CrossSync.rm_aio call + """ + if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + if node.func.attr == self.RM_AIO_FN_NAME and node.func.value.id == self.RM_AIO_CLASS_NAME: + return True + return False + def visit_Call(self, node): - if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name) and \ - node.func.attr == "rm_aio" and "CrossSync" in node.func.value.id: + if self._is_rm_aio_call(node): return self.visit(self.to_sync.visit(node.args[0])) return self.generic_visit(node) def visit_AsyncWith(self, node): """ - Async with statements are not fully wrapped by calls + `async with` statements can contain multiple async context managers. 
+ + If any of them contains a CrossSync.rm_aio statement, convert into standard `with` statement """ - found_rmaio = False - for item in node.items: - if isinstance(item.context_expr, ast.Call) and isinstance(item.context_expr.func, ast.Attribute) and isinstance(item.context_expr.func.value, ast.Name) and \ - item.context_expr.func.attr == "rm_aio" and "CrossSync" in item.context_expr.func.value.id: - found_rmaio = True - break - if found_rmaio: + if any(self._is_rm_aio_call(item.context_expr) for item in node.items + ): new_node = ast.copy_location( ast.With( - [self.generic_visit(item) for item in node.items], - [self.generic_visit(stmt) for stmt in node.body], + [self.visit(item) for item in node.items], + [self.visit(stmt) for stmt in node.body], ), node, ) @@ -185,8 +192,7 @@ def visit_AsyncFor(self, node): Async for statements are not fully wrapped by calls """ it = node.iter - if isinstance(it, ast.Call) and isinstance(it.func, ast.Attribute) and isinstance(it.func.value, ast.Name) and \ - it.func.attr == "rm_aio" and "CrossSync" in it.func.value.id: + if self._is_rm_aio_call(it): return ast.copy_location( ast.For( self.visit(node.target), From 54e3007a805c016b37ef67f95f7e14c90a932020 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 24 Oct 2024 16:43:22 -0700 Subject: [PATCH 331/360] accept None as empty string --- .../bigtable/data/_cross_sync/_decorators.py | 6 +++--- .../_cross_sync/test_cross_sync_decorators.py | 16 ++++++++++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/google/cloud/bigtable/data/_cross_sync/_decorators.py b/google/cloud/bigtable/data/_cross_sync/_decorators.py index 4e79331dd..2f4c8374f 100644 --- a/google/cloud/bigtable/data/_cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_cross_sync/_decorators.py @@ -201,7 +201,7 @@ def __init__( sync_name: str | None = None, *, replace_symbols: dict[str, str] | None = None, - docstring_format_vars: dict[str, tuple[str, str]] | None = None, + docstring_format_vars: dict[str, tuple[str | None, str | None]] | None = None, rm_aio: bool = False, add_mapping_for_name: str | None = None, ): @@ -209,10 +209,10 @@ def __init__( self.replace_symbols = replace_symbols docstring_format_vars = docstring_format_vars or {} self.async_docstring_format_vars = { - k: v[0] for k, v in docstring_format_vars.items() + k: v[0] or "" for k, v in docstring_format_vars.items() } self.sync_docstring_format_vars = { - k: v[1] for k, v in docstring_format_vars.items() + k: v[1] or "" for k, v in docstring_format_vars.items() } self.rm_aio = rm_aio self.add_mapping_for_name = add_mapping_for_name diff --git a/tests/unit/data/_cross_sync/test_cross_sync_decorators.py b/tests/unit/data/_cross_sync/test_cross_sync_decorators.py index fb35a5834..1f5bd4b0e 100644 --- a/tests/unit/data/_cross_sync/test_cross_sync_decorators.py +++ b/tests/unit/data/_cross_sync/test_cross_sync_decorators.py @@ -108,6 +108,10 @@ def test_class_decorator_adds_mapping(self): ["{A}", {"A": (1, 2)}, "1"], ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "1 3"], ["hello {world_var}", {"world_var": ("world", "moon")}, "hello world"], + ["{empty}", {"empty": ("", "")}, ""], + ["{empty}", {"empty": (None, None)}, ""], + ["maybe{empty}", {"empty": (None, "yes")}, "maybe"], + ["maybe{empty}", {"empty": (" no", None)}, "maybe no"] ], ) def test_class_decorator_docstring_update(self, docstring, format_vars, expected): @@ -123,8 +127,8 @@ class Class: assert Class.__doc__ == expected # check internal state instance = self._get_class()(sync_name="s", 
docstring_format_vars=format_vars) - async_replacements = {k: v[0] for k, v in format_vars.items()} - sync_replacements = {k: v[1] for k, v in format_vars.items()} + async_replacements = {k: v[0] or "" for k, v in format_vars.items()} + sync_replacements = {k: v[1] or "" for k, v in format_vars.items()} assert instance.async_docstring_format_vars == async_replacements assert instance.sync_docstring_format_vars == sync_replacements @@ -299,6 +303,10 @@ def test_async_decorator_no_docstring(self): ["{A}", {"A": (1, 2)}, "1"], ["{A} {B}", {"A": (1, 2), "B": (3, 4)}, "1 3"], ["hello {world_var}", {"world_var": ("world", "moon")}, "hello world"], + ["{empty}", {"empty": ("", "")}, ""], + ["{empty}", {"empty": (None, None)}, ""], + ["maybe{empty}", {"empty": (None, "yes")}, "maybe"], + ["maybe{empty}", {"empty": (" no", None)}, "maybe no"] ], ) def test_async_decorator_docstring_update(self, docstring, format_vars, expected): @@ -314,8 +322,8 @@ class Class: assert Class.__doc__ == expected # check internal state instance = self._get_class()(docstring_format_vars=format_vars) - async_replacements = {k: v[0] for k, v in format_vars.items()} - sync_replacements = {k: v[1] for k, v in format_vars.items()} + async_replacements = {k: v[0] or "" for k, v in format_vars.items()} + sync_replacements = {k: v[1] or "" for k, v in format_vars.items()} assert instance.async_docstring_format_vars == async_replacements assert instance.sync_docstring_format_vars == sync_replacements From 5811a98092b35df15fc5264a988ef25e3f0f3f51 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 24 Oct 2024 16:52:43 -0700 Subject: [PATCH 332/360] use None for empty string --- google/cloud/bigtable/data/_async/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index cb305ae60..d7efe13c2 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -118,11 +118,11 @@ class BigtableDataClientAsync(ClientWithProject): docstring_format_vars={ "LOOP_MESSAGE": ( "Client should be created within an async context (running event loop)", - "", + None, ), "RAISE_NO_LOOP": ( "RuntimeError: if called outside of an async context (no running event loop)", - "", + None, ), } ) From 480b139444f69cf143523331e5c1ae0982674f71 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Thu, 24 Oct 2024 23:53:15 +0000 Subject: [PATCH 333/360] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot?= =?UTF-8?q?=20post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- tests/unit/data/_cross_sync/test_cross_sync_decorators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/data/_cross_sync/test_cross_sync_decorators.py b/tests/unit/data/_cross_sync/test_cross_sync_decorators.py index 1f5bd4b0e..a9aa14d0a 100644 --- a/tests/unit/data/_cross_sync/test_cross_sync_decorators.py +++ b/tests/unit/data/_cross_sync/test_cross_sync_decorators.py @@ -111,7 +111,7 @@ def test_class_decorator_adds_mapping(self): ["{empty}", {"empty": ("", "")}, ""], ["{empty}", {"empty": (None, None)}, ""], ["maybe{empty}", {"empty": (None, "yes")}, "maybe"], - ["maybe{empty}", {"empty": (" no", None)}, "maybe no"] + ["maybe{empty}", {"empty": (" no", None)}, "maybe no"], ], ) def test_class_decorator_docstring_update(self, docstring, format_vars, 
expected): @@ -306,7 +306,7 @@ def test_async_decorator_no_docstring(self): ["{empty}", {"empty": ("", "")}, ""], ["{empty}", {"empty": (None, None)}, ""], ["maybe{empty}", {"empty": (None, "yes")}, "maybe"], - ["maybe{empty}", {"empty": (" no", None)}, "maybe no"] + ["maybe{empty}", {"empty": (" no", None)}, "maybe no"], ], ) def test_async_decorator_docstring_update(self, docstring, format_vars, expected): From 66fc807af4a1f3b6a150ca7139f148a78d4a0565 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 24 Oct 2024 16:55:19 -0700 Subject: [PATCH 334/360] move _MB_SIZE back to batcher --- google/cloud/bigtable/data/_async/mutations_batcher.py | 4 +++- google/cloud/bigtable/data/_helpers.py | 4 ---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 1780a902e..2603a9225 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -25,7 +25,6 @@ from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data._helpers import _MB_SIZE from google.cloud.bigtable.data.mutations import ( _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, @@ -44,6 +43,9 @@ __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen.mutations_batcher" +# used to make more readable default values +_MB_SIZE = 1024 * 1024 + @CrossSync.convert_class(sync_name="_FlowControl", add_mapping_for_name="_FlowControl") class _FlowControlAsync: diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index d09ccb204..dea31911e 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -48,10 +48,6 @@ "_WarmedInstanceKey", ["instance_name", "table_name", "app_profile_id"] ) -# used to make more readable default values -_MB_SIZE = 1024 * 1024 - - # enum used on method calls when table defaults should be used class TABLE_DEFAULT(enum.Enum): # default for mutate_row, sample_row_keys, check_and_mutate_row, and read_modify_write_row From 89a816a616fef1dc30c5cbfd614b9dd5e0b430b9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 25 Oct 2024 11:22:05 -0700 Subject: [PATCH 335/360] removed rm_aio call annotations --- .../bigtable/data/_async/_mutate_rows.py | 14 +- .../cloud/bigtable/data/_async/_read_rows.py | 3 +- google/cloud/bigtable/data/_async/client.py | 212 ++++++++---------- .../bigtable/data/_async/mutations_batcher.py | 36 ++- .../_async/execute_query_iterator.py | 28 +-- tests/system/data/test_system_async.py | 36 +-- tests/unit/data/_async/test_client.py | 4 +- .../data/_async/test_read_rows_acceptance.py | 4 +- .../_async/test_query_iterator.py | 2 +- 9 files changed, 137 insertions(+), 202 deletions(-) diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 705f03066..553fbf6a4 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -129,7 +129,7 @@ async def start(self): """ try: # trigger mutate_rows - CrossSync.rm_aio(await self._operation()) + await self._operation() except Exception as exc: # exceptions raised by retryable are added to the list of exceptions for all unfinalized mutations incomplete_indices = self.remaining_indices.copy() @@ -177,14 +177,12 
@@ async def _run_attempt(self): return # make gapic request try: - result_generator = CrossSync.rm_aio( - await self._gapic_fn( - timeout=next(self.timeout_generator), - entries=request_entries, - retry=None, - ) + result_generator = await self._gapic_fn( + timeout=next(self.timeout_generator), + entries=request_entries, + retry=None, ) - async for result_list in CrossSync.rm_aio(result_generator): + async for result_list in result_generator: for result in result_list.entries: # convert sub-request index to global index orig_idx = active_request_indices[result.index] diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index d0ca493a5..1edd90fa0 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -156,7 +156,7 @@ def _read_rows_attempt(self) -> CrossSync.Iterable[Row]: chunked_stream = self.chunk_stream(gapic_stream) return self.merge_rows(chunked_stream) - @CrossSync.convert(rm_aio=True) + @CrossSync.convert() async def chunk_stream( self, stream: CrossSync.Awaitable[CrossSync.Iterable[ReadRowsResponsePB]] ) -> CrossSync.Iterable[ReadRowsResponsePB.CellChunk]: @@ -211,7 +211,6 @@ async def chunk_stream( @staticmethod @CrossSync.convert( replace_symbols={"__aiter__": "__iter__", "__anext__": "__next__"}, - rm_aio=True, ) async def merge_rows( chunks: CrossSync.Iterable[ReadRowsResponsePB.CellChunk] | None, diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index d7efe13c2..ee6359740 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -280,12 +280,10 @@ async def close(self, timeout: float | None = 2.0): self._is_closed.set() for task in self._channel_refresh_tasks: task.cancel() - CrossSync.rm_aio(await self.transport.close()) + await self.transport.close() if self._executor: self._executor.shutdown(wait=False) - CrossSync.rm_aio( - await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) - ) + await CrossSync.wait(self._channel_refresh_tasks, timeout=timeout) self._channel_refresh_tasks = [] @CrossSync.convert @@ -325,10 +323,8 @@ async def _ping_and_warm_instances( ) for (instance_name, table_name, app_profile_id) in instance_list ] - result_list = CrossSync.rm_aio( - await CrossSync.gather_partials( - partial_list, return_exceptions=True, sync_executor=self._executor - ) + result_list = await CrossSync.gather_partials( + partial_list, return_exceptions=True, sync_executor=self._executor ) return [r or None for r in result_list] @@ -366,31 +362,27 @@ async def _manage_channel( if next_sleep > 0: # warm the current channel immediately channel = self.transport.channels[channel_idx] - CrossSync.rm_aio(await self._ping_and_warm_instances(channel)) + await self._ping_and_warm_instances(channel) # continuously refresh the channel every `refresh_interval` seconds while not self._is_closed.is_set(): - CrossSync.rm_aio( - await CrossSync.event_wait( - self._is_closed, - next_sleep, - async_break_early=False, # no need to interrupt sleep. Task will be cancelled on close - ) + await CrossSync.event_wait( + self._is_closed, + next_sleep, + async_break_early=False, # no need to interrupt sleep. 
Task will be cancelled on close ) if self._is_closed.is_set(): # don't refresh if client is closed break # prepare new channel for use new_channel = self.transport.grpc_channel._create_channel() - CrossSync.rm_aio(await self._ping_and_warm_instances(new_channel)) + await self._ping_and_warm_instances(new_channel) # cycle channel out of use, with long grace window before closure start_timestamp = time.monotonic() - CrossSync.rm_aio( - await self.transport.replace_channel( - channel_idx, - grace=grace_period, - new_channel=new_channel, - event=self._is_closed, - ) + await self.transport.replace_channel( + channel_idx, + grace=grace_period, + new_channel=new_channel, + event=self._is_closed, ) # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) @@ -429,9 +421,7 @@ async def _register_instance( # refresh tasks already running # call ping and warm on all existing channels for channel in self.transport.channels: - CrossSync.rm_aio( - await self._ping_and_warm_instances(channel, instance_key) - ) + await self._ping_and_warm_instances(channel, instance_key) else: # refresh tasks aren't active. start them as background tasks self._start_background_channel_refresh() @@ -634,8 +624,8 @@ async def __aenter__(self): @CrossSync.convert(sync_name="__exit__", replace_symbols={"__aexit__": "__exit__"}) async def __aexit__(self, exc_type, exc_val, exc_tb): - CrossSync.rm_aio(await self.close()) - CrossSync.rm_aio(await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb)) + await self.close() + await self._gapic_client.__aexit__(exc_type, exc_val, exc_tb) @CrossSync.convert_class(sync_name="Table", add_mapping_for_name="Table") @@ -871,15 +861,13 @@ async def read_rows( from any retries that failed google.api_core.exceptions.GoogleAPIError: raised if the request encounters an unrecoverable error """ - row_generator = CrossSync.rm_aio( - await self.read_rows_stream( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, - ) + row_generator = await self.read_rows_stream( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, ) - return CrossSync.rm_aio([row async for row in row_generator]) + return [row async for row in row_generator] @CrossSync.convert async def read_row( @@ -921,13 +909,11 @@ async def read_row( if row_key is None: raise ValueError("row_key must be string or bytes") query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1) - results = CrossSync.rm_aio( - await self.read_rows( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, - ) + results = await self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, ) if len(results) == 0: return None @@ -988,31 +974,27 @@ async def read_rows_sharded( @CrossSync.convert async def read_rows_with_semaphore(query): - async with CrossSync.rm_aio(concurrency_sem): + async with concurrency_sem: # calculate new timeout based on time left in overall operation shard_timeout = next(rpc_timeout_generator) if shard_timeout <= 0: raise DeadlineExceeded( "Operation timeout exceeded before starting query" ) - return CrossSync.rm_aio( - await self.read_rows( - query, - operation_timeout=shard_timeout, - attempt_timeout=min(attempt_timeout, shard_timeout), - retryable_errors=retryable_errors, - 
) + return await self.read_rows( + query, + operation_timeout=shard_timeout, + attempt_timeout=min(attempt_timeout, shard_timeout), + retryable_errors=retryable_errors, ) routine_list = [ partial(read_rows_with_semaphore, query) for query in sharded_query ] - batch_result = CrossSync.rm_aio( - await CrossSync.gather_partials( - routine_list, - return_exceptions=True, - sync_executor=self.client._executor, - ) + batch_result = await CrossSync.gather_partials( + routine_list, + return_exceptions=True, + sync_executor=self.client._executor, ) # collect results and errors @@ -1081,13 +1063,11 @@ async def row_exists( limit_filter = CellsRowLimitFilter(1) chain_filter = RowFilterChain(filters=[limit_filter, strip_filter]) query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter) - results = CrossSync.rm_aio( - await self.read_rows( - query, - operation_timeout=operation_timeout, - attempt_timeout=attempt_timeout, - retryable_errors=retryable_errors, - ) + results = await self.read_rows( + query, + operation_timeout=operation_timeout, + attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, ) return len(results) > 0 @@ -1150,27 +1130,21 @@ async def sample_row_keys( @CrossSync.convert async def execute_rpc(): - results = CrossSync.rm_aio( - await self.client._gapic_client.sample_row_keys( - table_name=self.table_name, - app_profile_id=self.app_profile_id, - timeout=next(attempt_timeout_gen), - metadata=metadata, - retry=None, - ) - ) - return CrossSync.rm_aio( - [(s.row_key, s.offset_bytes) async for s in results] + results = await self.client._gapic_client.sample_row_keys( + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=next(attempt_timeout_gen), + metadata=metadata, + retry=None, ) + return [(s.row_key, s.offset_bytes) async for s in results] - return CrossSync.rm_aio( - await CrossSync.retry_target( - execute_rpc, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, - ) + return await CrossSync.retry_target( + execute_rpc, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, ) @CrossSync.convert(replace_symbols={"MutationsBatcherAsync": "MutationsBatcher"}) @@ -1296,14 +1270,12 @@ async def mutate_row( ), retry=None, ) - return CrossSync.rm_aio( - await CrossSync.retry_target( - target, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, - ) + return await CrossSync.retry_target( + target, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, ) @CrossSync.convert @@ -1360,7 +1332,7 @@ async def bulk_mutate_rows( attempt_timeout, retryable_exceptions=retryable_excs, ) - CrossSync.rm_aio(await operation.start()) + await operation.start() @CrossSync.convert async def check_and_mutate_row( @@ -1418,20 +1390,18 @@ async def check_and_mutate_row( metadata = _make_metadata( self.table_name, self.app_profile_id, instance_name=None ) - result = CrossSync.rm_aio( - await self.client._gapic_client.check_and_mutate_row( - true_mutations=true_case_list, - false_mutations=false_case_list, - predicate_filter=predicate._to_pb() if predicate is not None else None, - row_key=row_key.encode("utf-8") - if isinstance(row_key, str) - else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, - ) + result = await self.client._gapic_client.check_and_mutate_row( + 
true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, ) return result.predicate_matched @@ -1476,18 +1446,16 @@ async def read_modify_write_row( metadata = _make_metadata( self.table_name, self.app_profile_id, instance_name=None ) - result = CrossSync.rm_aio( - await self.client._gapic_client.read_modify_write_row( - rules=[rule._to_pb() for rule in rules], - row_key=row_key.encode("utf-8") - if isinstance(row_key, str) - else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, - ) + result = await self.client._gapic_client.read_modify_write_row( + rules=[rule._to_pb() for rule in rules], + row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, ) # construct Row from result return Row._from_pb(result.row) @@ -1499,9 +1467,7 @@ async def close(self): """ if self._register_instance_future: self._register_instance_future.cancel() - CrossSync.rm_aio( - await self.client._remove_instance_registration(self.instance_id, self) - ) + await self.client._remove_instance_registration(self.instance_id, self) @CrossSync.convert(sync_name="__enter__") async def __aenter__(self): @@ -1512,7 +1478,7 @@ async def __aenter__(self): grpc channels will be warmed for the specified instance """ if self._register_instance_future: - CrossSync.rm_aio(await self._register_instance_future) + await self._register_instance_future return self @CrossSync.convert(sync_name="__exit__") @@ -1523,4 +1489,4 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): Unregister this instance with the client, so that grpc channels will no longer be warmed """ - CrossSync.rm_aio(await self.close()) + await self.close() \ No newline at end of file diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 2603a9225..d8ecb7d32 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -123,7 +123,7 @@ async def remove_from_flow( self._in_flight_mutation_count -= total_count self._in_flight_mutation_bytes -= total_size # notify any blocked requests that there is additional capacity - async with CrossSync.rm_aio(self._capacity_condition): + async with self._capacity_condition: self._capacity_condition.notify_all() @CrossSync.convert @@ -149,7 +149,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] start_idx = end_idx batch_mutation_count = 0 # fill up batch until we hit capacity - async with CrossSync.rm_aio(self._capacity_condition): + async with self._capacity_condition: while end_idx < len(mutations): next_entry = mutations[end_idx] next_size = next_entry.size() @@ -170,10 +170,8 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] break else: # batch is empty. 
Block until we have capacity - CrossSync.rm_aio( - await self._capacity_condition.wait_for( - lambda: self._has_capacity(next_count, next_size) - ) + await self._capacity_condition.wait_for( + lambda: self._has_capacity(next_count, next_size) ) yield mutations[start_idx:end_idx] @@ -289,10 +287,8 @@ async def _timer_routine(self, interval: float | None) -> None: return None while not self._closed.is_set(): # wait until interval has passed, or until closed - CrossSync.rm_aio( - await CrossSync.event_wait( - self._closed, timeout=interval, async_break_early=False - ) + await CrossSync.event_wait( + self._closed, timeout=interval, async_break_early=False ) if not self._closed.is_set() and self._staged_entries: self._schedule_flush() @@ -325,7 +321,7 @@ async def append(self, mutation_entry: RowMutationEntry): ): self._schedule_flush() # yield to the event loop to allow flush to run - CrossSync.rm_aio(await CrossSync.yield_to_event_loop()) + await CrossSync.yield_to_event_loop() def _schedule_flush(self) -> CrossSync.Future[None] | None: """ @@ -357,17 +353,13 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): """ # flush new entries in_process_requests: list[CrossSync.Future[list[FailedMutationEntryError]]] = [] - async for batch in CrossSync.rm_aio( - self._flow_control.add_to_flow(new_entries) - ): + async for batch in self._flow_control.add_to_flow(new_entries): batch_task = CrossSync.create_task( self._execute_mutate_rows, batch, sync_executor=self._sync_rpc_executor ) in_process_requests.append(batch_task) # wait for all inflight requests to complete - found_exceptions = CrossSync.rm_aio( - await self._wait_for_batch_results(*in_process_requests) - ) + found_exceptions = await self._wait_for_batch_results(*in_process_requests) # update exception data to reflect any new errors self._entries_processed_since_last_raise += len(new_entries) self._add_exceptions(found_exceptions) @@ -397,7 +389,7 @@ async def _execute_mutate_rows( attempt_timeout=self._attempt_timeout, retryable_exceptions=self._retryable_errors, ) - CrossSync.rm_aio(await operation.start()) + await operation.start() except MutationsExceptionGroup as e: # strip index information from exceptions, since it is not useful in a batch context for subexc in e.exceptions: @@ -405,7 +397,7 @@ async def _execute_mutate_rows( return list(e.exceptions) finally: # mark batch as complete in flow control - CrossSync.rm_aio(await self._flow_control.remove_from_flow(batch)) + await self._flow_control.remove_from_flow(batch) return [] def _add_exceptions(self, excs: list[Exception]): @@ -465,7 +457,7 @@ async def __aexit__(self, exc_type, exc, tb): Flushes the batcher and cleans up resources. 
""" - CrossSync.rm_aio(await self.close()) + await self.close() @property def closed(self) -> bool: @@ -490,7 +482,7 @@ async def close(self): if self._sync_rpc_executor: with self._sync_rpc_executor: self._sync_rpc_executor.shutdown(wait=True) - CrossSync.rm_aio(await CrossSync.wait([*self._flush_jobs, self._flush_timer])) + await CrossSync.wait([*self._flush_jobs, self._flush_timer]) atexit.unregister(self._on_exit) # raise unreported exceptions self._raise_exceptions() @@ -530,7 +522,7 @@ async def _wait_for_batch_results( for task in tasks: if CrossSync.is_async: # futures don't need to be awaited in sync mode - CrossSync.rm_aio(await task) + await task try: exc_list = task.result() if exc_list: diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index d58312fa5..41e091190 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -159,13 +159,11 @@ async def _make_request_with_resume_token(self): "resume_token": resume_token, } ) - return CrossSync.rm_aio( - await self._client._gapic_client.execute_query( - request, - timeout=next(self._attempt_timeout_gen), - metadata=self._req_metadata, - retry=None, - ) + return await self._client._gapic_client.execute_query( + request, + timeout=next(self._attempt_timeout_gen), + metadata=self._req_metadata, + retry=None, ) @CrossSync.convert(replace_symbols={"__anext__": "__next__"}) @@ -175,7 +173,7 @@ async def _fetch_metadata(self) -> None: is retrieved as part of this call. """ if self._byte_cursor.metadata is None: - metadata_msg = CrossSync.rm_aio(await self._stream.__anext__()) + metadata_msg = await self._stream.__anext__() self._byte_cursor.consume_metadata(metadata_msg) @CrossSync.convert @@ -184,9 +182,9 @@ async def _next_impl(self) -> CrossSync.Iterator[QueryResultRow]: Generator wrapping the response stream which parses the stream results and returns full `QueryResultRow`s. """ - CrossSync.rm_aio(await self._fetch_metadata()) + await self._fetch_metadata() - async for response in CrossSync.rm_aio(self._stream): + async for response in self._stream: try: bytes_to_parse = self._byte_cursor.consume(response) if bytes_to_parse is None: @@ -203,13 +201,13 @@ async def _next_impl(self) -> CrossSync.Iterator[QueryResultRow]: for result in results: yield result - CrossSync.rm_aio(await self.close()) + await self.close() @CrossSync.convert(sync_name="__next__", replace_symbols={"__anext__": "__next__"}) async def __anext__(self) -> QueryResultRow: if self._is_closed: raise CrossSync.StopIteration - return CrossSync.rm_aio(await self._result_generator.__anext__()) + return await self._result_generator.__anext__() @CrossSync.convert(sync_name="__iter__") def __aiter__(self): @@ -226,7 +224,7 @@ async def metadata(self) -> Optional[Metadata]: # Metadata should be present in the first response in a stream. 
if self._byte_cursor.metadata is None: try: - CrossSync.rm_aio(await self._fetch_metadata()) + await self._fetch_metadata() except CrossSync.StopIteration: return None return self._byte_cursor.metadata @@ -241,6 +239,4 @@ async def close(self) -> None: self._is_closed = True if self._register_instance_task is not None: self._register_instance_task.cancel() - CrossSync.rm_aio( - await self._client._remove_instance_registration(self._instance_id, self) - ) + await self._client._remove_instance_registration(self._instance_id, self) diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index ee05acee7..804cb3bfd 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -64,7 +64,7 @@ async def add_row( } ], } - CrossSync.rm_aio(await self.table.client._gapic_client.mutate_row(request)) + await self.table.client._gapic_client.mutate_row(request) self.rows.append(row_key) @CrossSync.convert @@ -77,7 +77,7 @@ async def delete_rows(self): for row in self.rows ], } - CrossSync.rm_aio(await self.table.client._gapic_client.mutate_rows(request)) + await self.table.client._gapic_client.mutate_rows(request) @CrossSync.convert_class(sync_name="TestSystem") @@ -86,20 +86,13 @@ class TestSystemAsync: @CrossSync.pytest_fixture(scope="session") async def client(self): project = os.getenv("GOOGLE_CLOUD_PROJECT") or None - async with CrossSync.rm_aio( - CrossSync.DataClient(project=project, pool_size=4) - ) as client: + async with CrossSync.DataClient(project=project, pool_size=4) as client: yield client @CrossSync.convert @CrossSync.pytest_fixture(scope="session") async def table(self, client, table_id, instance_id): - async with CrossSync.rm_aio( - client.get_table( - instance_id, - table_id, - ) - ) as table: + async with client.get_table(instance_id, table_id) as table: yield table @CrossSync.drop @@ -149,9 +142,7 @@ async def _retrieve_cell_value(self, table, row_key): """ from google.cloud.bigtable.data import ReadRowsQuery - row_list = CrossSync.rm_aio( - await table.read_rows(ReadRowsQuery(row_keys=row_key)) - ) + row_list = await table.read_rows(ReadRowsQuery(row_keys=row_key)) assert len(row_list) == 1 row = row_list[0] cell = row.cells[0] @@ -169,16 +160,11 @@ async def _create_row_and_mutation( row_key = uuid.uuid4().hex.encode() family = TEST_FAMILY qualifier = b"test-qualifier" - CrossSync.rm_aio( - await temp_rows.add_row( - row_key, family=family, qualifier=qualifier, value=start_value - ) + await temp_rows.add_row( + row_key, family=family, qualifier=qualifier, value=start_value ) # ensure cell is initialized - assert ( - CrossSync.rm_aio(await self._retrieve_cell_value(table, row_key)) - == start_value - ) + assert await self._retrieve_cell_value(table, row_key) == start_value mutation = SetCell(family=TEST_FAMILY, qualifier=qualifier, new_value=new_value) return row_key, mutation @@ -188,7 +174,7 @@ async def _create_row_and_mutation( async def temp_rows(self, table): builder = CrossSync.TempRowBuilder(table) yield builder - CrossSync.rm_aio(await builder.delete_rows()) + await builder.delete_rows() @pytest.mark.usefixtures("table") @pytest.mark.usefixtures("client") @@ -234,9 +220,7 @@ async def test_mutation_set_cell(self, table, temp_rows): """ row_key = b"bulk_mutate" new_value = uuid.uuid4().hex.encode() - row_key, mutation = CrossSync.rm_aio( - await self._create_row_and_mutation(table, temp_rows, new_value=new_value) - ) + row_key, mutation = await self._create_row_and_mutation(table, temp_rows, 
new_value=new_value) await table.mutate_row(row_key, mutation) # ensure cell is updated diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 1c91d7992..0d7850975 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1511,7 +1511,7 @@ async def __anext__(self): self.idx += 1 if len(self.chunk_list) > self.idx: if sleep_time: - CrossSync.rm_aio(await CrossSync.sleep(self.sleep_time)) + await CrossSync.sleep(self.sleep_time) chunk = self.chunk_list[self.idx] if isinstance(chunk, Exception): raise chunk @@ -1526,7 +1526,7 @@ def cancel(self): @CrossSync.convert async def execute_fn(self, table, *args, **kwargs): - return CrossSync.rm_aio(await table.read_rows(*args, **kwargs)) + return await table.read_rows(*args, **kwargs) @CrossSync.pytest async def test_read_rows(self): diff --git a/tests/unit/data/_async/test_read_rows_acceptance.py b/tests/unit/data/_async/test_read_rows_acceptance.py index 6ec783069..45d139182 100644 --- a/tests/unit/data/_async/test_read_rows_acceptance.py +++ b/tests/unit/data/_async/test_read_rows_acceptance.py @@ -91,7 +91,7 @@ async def _row_stream(): ) merger = self._get_operation_class().merge_rows(chunker) results = [] - async for row in CrossSync.rm_aio(merger): + async for row in merger: results.append(row) return results @@ -113,7 +113,7 @@ async def _scenerio_stream(): instance, self._coro_wrapper(_scenerio_stream()) ) merger = self._get_operation_class().merge_rows(chunker) - async for row in CrossSync.rm_aio(merger): + async for row in merger: for cell in row: cell_result = ReadRowsTest.Result( row_key=cell.row_key, diff --git a/tests/unit/data/execute_query/_async/test_query_iterator.py b/tests/unit/data/execute_query/_async/test_query_iterator.py index 75b8a2c8e..fd99b1e10 100644 --- a/tests/unit/data/execute_query/_async/test_query_iterator.py +++ b/tests/unit/data/execute_query/_async/test_query_iterator.py @@ -46,7 +46,7 @@ async def __anext__(self): if self.idx >= len(self._values): raise CrossSync.StopIteration if self._delay is not None: - CrossSync.rm_aio(await CrossSync.sleep(self._delay)) + await CrossSync.sleep(self._delay) value = self._values[self.idx] self.idx += 1 return value From c6f053bd7de5e000c7e27ae93e3ce6396383a4fd Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 25 Oct 2024 11:22:18 -0700 Subject: [PATCH 336/360] fixed broken imports --- google/cloud/bigtable/data/__init__.py | 2 +- google/cloud/bigtable/data/_async/client.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index e176ec765..8a8fa35c7 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -61,7 +61,7 @@ ExecuteQueryIteratorAsync, ) -from google.cloud.bigtable.data._sync.cross_sync import CrossSync +from google.cloud.bigtable.data._cross_sync import CrossSync CrossSync.add_mapping("GapicClient", BigtableAsyncClient) CrossSync.add_mapping("PooledChannel", AsyncPooledChannel) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index ee6359740..24f8fd451 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -70,7 +70,6 @@ from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import _attempt_timeout_generator 
-from google.cloud.bigtable.data._helpers import _MB_SIZE from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule @@ -86,7 +85,7 @@ PooledBigtableGrpcAsyncIOTransport as PooledTransportType, ) from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, + MutationsBatcherAsync, _MB_SIZE ) from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( ExecuteQueryIteratorAsync, @@ -95,7 +94,7 @@ else: from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport as PooledTransportType # type: ignore from google.cloud.bigtable.data._sync_autogen.mutations_batcher import ( # noqa: F401 - MutationsBatcher, + MutationsBatcher, _MB_SIZE ) from google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator import ( # noqa: F401 ExecuteQueryIterator, From bee3e84de3465aa4cfa8c98ebb5540882f450e4a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 25 Oct 2024 11:27:24 -0700 Subject: [PATCH 337/360] rm_aio at function level by default --- google/cloud/bigtable/data/_cross_sync/_decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_cross_sync/_decorators.py b/google/cloud/bigtable/data/_cross_sync/_decorators.py index 2f4c8374f..e87b6339d 100644 --- a/google/cloud/bigtable/data/_cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_cross_sync/_decorators.py @@ -309,7 +309,7 @@ def __init__( *, replace_symbols: dict[str, str] | None = None, docstring_format_vars: dict[str, tuple[str, str]] | None = None, - rm_aio: bool = False, + rm_aio: bool = True, ): super().__init__( sync_name=sync_name, From 36c78ba7fc0660b9ce68af91b127decced688f8a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 25 Oct 2024 11:27:53 -0700 Subject: [PATCH 338/360] render to disk by default --- .cross_sync/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.cross_sync/generate.py b/.cross_sync/generate.py index b0d855b96..5158d0f37 100644 --- a/.cross_sync/generate.py +++ b/.cross_sync/generate.py @@ -46,7 +46,7 @@ def __init__(self, output_path: str, ast_tree, header: str | None = None): self.tree = ast_tree self.header = header or "" - def render(self, with_formatter=True, save_to_disk: bool = False) -> str: + def render(self, with_formatter=True, save_to_disk: bool = True) -> str: """ Render the file to a string, and optionally save to disk From 3c44095c5165c187ab4f1baeccb47e7a7b0e1861 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 25 Oct 2024 11:40:40 -0700 Subject: [PATCH 339/360] fixed lint --- google/cloud/bigtable/data/_async/client.py | 16 +++++++--------- google/cloud/bigtable/data/_helpers.py | 1 + noxfile.py | 8 +++++++- tests/system/data/test_system_async.py | 4 +++- .../_cross_sync/test_cross_sync_decorators.py | 4 ++-- .../execute_query/_async/test_query_iterator.py | 4 +++- 6 files changed, 23 insertions(+), 14 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 24f8fd451..d3ba0acb6 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -85,7 +85,8 @@ PooledBigtableGrpcAsyncIOTransport as PooledTransportType, ) from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, _MB_SIZE + MutationsBatcherAsync, + _MB_SIZE, ) from 
google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( ExecuteQueryIteratorAsync, @@ -94,7 +95,8 @@ else: from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import PooledBigtableGrpcTransport as PooledTransportType # type: ignore from google.cloud.bigtable.data._sync_autogen.mutations_batcher import ( # noqa: F401 - MutationsBatcher, _MB_SIZE + MutationsBatcher, + _MB_SIZE, ) from google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator import ( # noqa: F401 ExecuteQueryIterator, @@ -1393,9 +1395,7 @@ async def check_and_mutate_row( true_mutations=true_case_list, false_mutations=false_case_list, predicate_filter=predicate._to_pb() if predicate is not None else None, - row_key=row_key.encode("utf-8") - if isinstance(row_key, str) - else row_key, + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, table_name=self.table_name, app_profile_id=self.app_profile_id, metadata=metadata, @@ -1447,9 +1447,7 @@ async def read_modify_write_row( ) result = await self.client._gapic_client.read_modify_write_row( rules=[rule._to_pb() for rule in rules], - row_key=row_key.encode("utf-8") - if isinstance(row_key, str) - else row_key, + row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, table_name=self.table_name, app_profile_id=self.app_profile_id, metadata=metadata, @@ -1488,4 +1486,4 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): Unregister this instance with the client, so that grpc channels will no longer be warmed """ - await self.close() \ No newline at end of file + await self.close() diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index dea31911e..2d36c521f 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -48,6 +48,7 @@ "_WarmedInstanceKey", ["instance_name", "table_name", "app_profile_id"] ) + # enum used on method calls when table defaults should be used class TABLE_DEFAULT(enum.Enum): # default for mutate_row, sample_row_keys, check_and_mutate_row, and read_modify_write_row diff --git a/noxfile.py b/noxfile.py index a3964b0f5..fe1a089df 100644 --- a/noxfile.py +++ b/noxfile.py @@ -297,7 +297,13 @@ def conformance(session, client_type): install_unittest_dependencies(session, "-c", constraints_path) with session.chdir("test_proxy"): # download the conformance test suite - session.run("bash", "-e", "run_tests.sh", external=True, env={"CLIENT_TYPE": client_type}) + session.run( + "bash", + "-e", + "run_tests.sh", + external=True, + env={"CLIENT_TYPE": client_type}, + ) @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 804cb3bfd..e856910b2 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -220,7 +220,9 @@ async def test_mutation_set_cell(self, table, temp_rows): """ row_key = b"bulk_mutate" new_value = uuid.uuid4().hex.encode() - row_key, mutation = await self._create_row_and_mutation(table, temp_rows, new_value=new_value) + row_key, mutation = await self._create_row_and_mutation( + table, temp_rows, new_value=new_value + ) await table.mutate_row(row_key, mutation) # ensure cell is updated diff --git a/tests/unit/data/_cross_sync/test_cross_sync_decorators.py b/tests/unit/data/_cross_sync/test_cross_sync_decorators.py index 1f5bd4b0e..a9aa14d0a 100644 --- a/tests/unit/data/_cross_sync/test_cross_sync_decorators.py +++ 
b/tests/unit/data/_cross_sync/test_cross_sync_decorators.py @@ -111,7 +111,7 @@ def test_class_decorator_adds_mapping(self): ["{empty}", {"empty": ("", "")}, ""], ["{empty}", {"empty": (None, None)}, ""], ["maybe{empty}", {"empty": (None, "yes")}, "maybe"], - ["maybe{empty}", {"empty": (" no", None)}, "maybe no"] + ["maybe{empty}", {"empty": (" no", None)}, "maybe no"], ], ) def test_class_decorator_docstring_update(self, docstring, format_vars, expected): @@ -306,7 +306,7 @@ def test_async_decorator_no_docstring(self): ["{empty}", {"empty": ("", "")}, ""], ["{empty}", {"empty": (None, None)}, ""], ["maybe{empty}", {"empty": (None, "yes")}, "maybe"], - ["maybe{empty}", {"empty": (" no", None)}, "maybe no"] + ["maybe{empty}", {"empty": (" no", None)}, "maybe no"], ], ) def test_async_decorator_docstring_update(self, docstring, format_vars, expected): diff --git a/tests/unit/data/execute_query/_async/test_query_iterator.py b/tests/unit/data/execute_query/_async/test_query_iterator.py index fd99b1e10..9bdf17c27 100644 --- a/tests/unit/data/execute_query/_async/test_query_iterator.py +++ b/tests/unit/data/execute_query/_async/test_query_iterator.py @@ -27,7 +27,9 @@ import mock # type: ignore -__CROSS_SYNC_OUTPUT__ = "tests.unit.data.execute_query._sync_autogen.test_query_iterator" +__CROSS_SYNC_OUTPUT__ = ( + "tests.unit.data.execute_query._sync_autogen.test_query_iterator" +) @CrossSync.convert_class(sync_name="MockIterator") From 6135ccced7b3791ceca70c509b7c7b968d9ae593 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 25 Oct 2024 11:42:49 -0700 Subject: [PATCH 340/360] fixed mypy issue --- google/cloud/bigtable/data/_cross_sync/_decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/data/_cross_sync/_decorators.py b/google/cloud/bigtable/data/_cross_sync/_decorators.py index e87b6339d..f37b05b64 100644 --- a/google/cloud/bigtable/data/_cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_cross_sync/_decorators.py @@ -308,7 +308,7 @@ def __init__( sync_name: str | None = None, *, replace_symbols: dict[str, str] | None = None, - docstring_format_vars: dict[str, tuple[str, str]] | None = None, + docstring_format_vars: dict[str, tuple[str | None, str | None]] | None = None, rm_aio: bool = True, ): super().__init__( From ff9d01974a07c68b23cf8a0a6ef3f263f1f34370 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 25 Oct 2024 11:55:12 -0700 Subject: [PATCH 341/360] fixed unit tests --- .../data/_cross_sync/test_cross_sync_decorators.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unit/data/_cross_sync/test_cross_sync_decorators.py b/tests/unit/data/_cross_sync/test_cross_sync_decorators.py index a9aa14d0a..3be579379 100644 --- a/tests/unit/data/_cross_sync/test_cross_sync_decorators.py +++ b/tests/unit/data/_cross_sync/test_cross_sync_decorators.py @@ -266,13 +266,13 @@ def test_ctor_defaults(self): assert instance.replace_symbols is None assert instance.async_docstring_format_vars == {} assert instance.sync_docstring_format_vars == {} - assert instance.rm_aio is False + assert instance.rm_aio is True def test_ctor(self): sync_name = "sync_name" replace_symbols = {"a": "b"} docstring_format_vars = {"A": (1, 2)} - rm_aio = True + rm_aio = False instance = self._get_class()( sync_name=sync_name, @@ -331,7 +331,7 @@ def test_sync_ast_transform_remove_adef(self): """ Should convert `async def` methods to `def` methods """ - decorator = self._get_class()() + decorator = 
self._get_class()(rm_aio=False) mock_node = ast.AsyncFunctionDef( name="test_method", args=ast.arguments(), body=[] ) @@ -345,7 +345,7 @@ def test_sync_ast_transform_replaces_name(self, globals_mock): """ Should update the name of the method if sync_name is set """ - decorator = self._get_class()(sync_name="new_method_name") + decorator = self._get_class()(sync_name="new_method_name", rm_aio=False) mock_node = ast.AsyncFunctionDef( name="old_method_name", args=ast.arguments(), body=[] ) @@ -375,7 +375,7 @@ def test_sync_ast_transform_replace_symbols(self): Should call SymbolReplacer with replace_symbols if replace_symbols is set """ replace_symbols = {"old_symbol": "new_symbol"} - decorator = self._get_class()(replace_symbols=replace_symbols) + decorator = self._get_class()(replace_symbols=replace_symbols, rm_aio=False) mock_node = ast.AsyncFunctionDef( name="test_method", args=ast.arguments(), body=[] ) @@ -405,7 +405,7 @@ def test_sync_ast_transform_add_docstring_format( """ If docstring_format_vars is set, should format the docstring of the new method """ - decorator = self._get_class()(docstring_format_vars=format_vars) + decorator = self._get_class()(docstring_format_vars=format_vars, rm_aio=False) mock_node = ast.AsyncFunctionDef( name="test_method", args=ast.arguments(), From 26eeb0c45692b903d17a61933c79b677af75865e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 31 Oct 2024 14:52:16 -0700 Subject: [PATCH 342/360] regnerated sync files --- google/cloud/bigtable/data/_sync_autogen/client.py | 6 ++++-- .../cloud/bigtable/data/_sync_autogen/mutations_batcher.py | 2 +- tests/unit/data/_sync_autogen/test_client.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index e9b1e564f..37e1349bd 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -59,7 +59,6 @@ from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import _attempt_timeout_generator -from google.cloud.bigtable.data._helpers import _MB_SIZE from google.cloud.bigtable.data.mutations import Mutation, RowMutationEntry from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule from google.cloud.bigtable.data.row_filters import RowFilter @@ -71,7 +70,10 @@ from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( PooledBigtableGrpcTransport as PooledTransportType, ) -from google.cloud.bigtable.data._sync_autogen.mutations_batcher import MutationsBatcher +from google.cloud.bigtable.data._sync_autogen.mutations_batcher import ( + MutationsBatcher, + _MB_SIZE, +) from google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator import ( ExecuteQueryIterator, ) diff --git a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py index 2779ffd92..c4f47b41c 100644 --- a/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py +++ b/google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py @@ -26,7 +26,6 @@ from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import TABLE_DEFAULT -from google.cloud.bigtable.data._helpers import _MB_SIZE from 
google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT from google.cloud.bigtable.data.mutations import Mutation from google.cloud.bigtable.data._cross_sync import CrossSync @@ -34,6 +33,7 @@ if TYPE_CHECKING: from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data._sync_autogen.client import Table as TableType +_MB_SIZE = 1024 * 1024 @CrossSync._Sync_Impl.add_mapping_decorator("_FlowControl") diff --git a/tests/unit/data/_sync_autogen/test_client.py b/tests/unit/data/_sync_autogen/test_client.py index 419c3a5b0..d31a448ab 100644 --- a/tests/unit/data/_sync_autogen/test_client.py +++ b/tests/unit/data/_sync_autogen/test_client.py @@ -2734,7 +2734,7 @@ def __next__(self): raise value return value - async def __anext__(self): + def __anext__(self): return self.__next__() return MockStream(sample_list) From c95ca6823d8389698bb93169d34d6f725ddc4226 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 8 Nov 2024 14:20:52 -0800 Subject: [PATCH 343/360] addressing broken tests --- google/cloud/bigtable/data/_async/client.py | 4 +- tests/unit/data/_async/test_client.py | 72 ++++++++++----------- 2 files changed, 35 insertions(+), 41 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 0a37ff2e1..689202267 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -361,8 +361,8 @@ async def _manage_channel( start_timestamp = time.monotonic() # prepare new channel for use old_channel = self.transport.grpc_channel - new_channel = self.transport.grpc_channel._create_channel() - await self._ping_and_warm_instances(new_channel) + new_channel = self.transport.create_channel() + await self._ping_and_warm_instances(channel=new_channel) # cycle channel out of use, with long grace window before closure self.transport._grpc_channel = new_channel await old_channel.close(grace_period) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index c43570d05..a8f1ad42b 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -106,7 +106,6 @@ async def test_ctor_super_inits(self): client_options = {"api_endpoint": "foo.bar:1234"} options_parsed = client_options_lib.from_dict(client_options) asyncio_portion = "-async" if CrossSync.is_async else "" - transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" with mock.patch.object( CrossSync.GapicClient, "__init__" ) as bigtable_client_init: @@ -503,7 +502,7 @@ async def test__manage_channel_refresh(self, num_cycles): grpc_lib = grpc.aio if CrossSync.is_async else grpc new_channel = grpc_lib.insecure_channel("localhost:8080") - with mock.patch.object(asyncio, "sleep") as sleep: + with mock.patch.object(CrossSync, "event_wait") as sleep: sleep.side_effect = [None for i in range(num_cycles)] + [ asyncio.CancelledError ] @@ -602,7 +601,7 @@ async def test__register_instance_duplicate(self): instance_owners = {} client_mock._active_instances = active_instances client_mock._instance_owners = instance_owners - client_mock._channel_refresh_tasks = [object()] + client_mock._channel_refresh_task = object() mock_channels = [mock.Mock()] client_mock.transport.channels = mock_channels client_mock._ping_and_warm_instances = CrossSync.Mock() @@ -659,12 +658,7 @@ async def test__register_instance_state( instance_owners = {} client_mock._active_instances = active_instances client_mock._instance_owners = 
instance_owners - client_mock._channel_refresh_tasks = [] - client_mock._start_background_channel_refresh.side_effect = ( - lambda: client_mock._channel_refresh_tasks.append(mock.Mock) - ) - mock_channels = [mock.Mock() for i in range(5)] - client_mock.transport.channels = mock_channels + client_mock._channel_refresh_task = None client_mock._ping_and_warm_instances = CrossSync.Mock() table_mock = mock.Mock() # register instances @@ -951,7 +945,6 @@ async def test_close(self): async def test_close_with_timeout(self): expected_timeout = 19 client = self._make_client(project="project-id", use_emulator=False) - tasks = list(client._channel_refresh_tasks) with mock.patch.object(CrossSync, "wait", CrossSync.Mock()) as wait_for_mock: await client.close(timeout=expected_timeout) wait_for_mock.assert_called_once() @@ -1275,36 +1268,37 @@ async def test_customizable_retryable_errors( @CrossSync.pytest @CrossSync.convert async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): - """check that all requests attach proper metadata headers""" - profile = "profile" if include_app_profile else None - with mock.patch.object( - CrossSync.GapicClient, gapic_fn, CrossSync.Mock() - ) as gapic_mock: - gapic_mock.side_effect = RuntimeError("stop early") - async with self._make_client() as client: - table = self._get_target_class()( - client, "instance-id", "table-id", profile - ) - try: - test_fn = table.__getattribute__(fn_name) - maybe_stream = await test_fn(*fn_args) - [i async for i in maybe_stream] - except Exception: - # we expect an exception from attempting to call the mock - pass - kwargs = gapic_mock.call_args_list[0][1] - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata + from google.cloud.bigtable.data import TableAsync + profile = "profile" if include_app_profile else None + client = self._make_client() + # create mock for rpc stub + transport_mock = mock.MagicMock() + rpc_mock = mock.AsyncMock() + transport_mock._wrapped_methods.__getitem__.return_value = rpc_mock + client._gapic_client._client._transport = transport_mock + client._gapic_client._client._is_universe_domain_valid = True + table = self._get_target_class()(client, "instance-id", "table-id", profile) + try: + test_fn = table.__getattribute__(fn_name) + maybe_stream = await test_fn(*fn_args) + [i async for i in maybe_stream] + except Exception: + # we expect an exception from attempting to call the mock + pass + assert rpc_mock.call_count == 1 + kwargs = rpc_mock.call_args_list[0].kwargs + metadata = kwargs["metadata"] + # expect single metadata entry + assert len(metadata) == 1 + # expect x-goog-request-params tag + assert metadata[0][0] == "x-goog-request-params" + routing_str = metadata[0][1] + assert "table_name=" + table.table_name in routing_str + if include_app_profile: + assert "app_profile_id=profile" in routing_str + else: + assert "app_profile_id=" not in routing_str @CrossSync.convert_class( "TestReadRows", From e7ce0d0b8254e2e726e56b47c22f5986cd8ffb12 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 11 Nov 2024 13:27:14 -0800 Subject: [PATCH 344/360] regenerated files --- .../data/_sync_autogen/_mutate_rows.py | 5 - 
.../bigtable/data/_sync_autogen/_read_rows.py | 10 +- .../bigtable/data/_sync_autogen/client.py | 133 ++---- .../_sync_autogen/execute_query_iterator.py | 6 +- tests/system/data/test_system_autogen.py | 6 +- .../data/_sync_autogen/test__mutate_rows.py | 7 +- .../data/_sync_autogen/test__read_rows.py | 6 - tests/unit/data/_sync_autogen/test_client.py | 442 ++++++------------ 8 files changed, 205 insertions(+), 410 deletions(-) diff --git a/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py b/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py index 7f488db5f..0f6edb395 100644 --- a/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py +++ b/google/cloud/bigtable/data/_sync_autogen/_mutate_rows.py @@ -21,7 +21,6 @@ from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries import google.cloud.bigtable.data.exceptions as bt_exceptions -from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _retry_exception_factory from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT @@ -69,14 +68,10 @@ def __init__( raise ValueError( f"mutate_rows requests can contain at most {_MUTATE_ROWS_REQUEST_MUTATION_LIMIT} mutations across all entries. Found {total_mutations}." ) - metadata = _make_metadata( - table.table_name, table.app_profile_id, instance_name=None - ) self._gapic_fn = functools.partial( gapic_client.mutate_rows, table_name=table.table_name, app_profile_id=table.app_profile_id, - metadata=metadata, retry=None, ) self.is_retryable = retries.if_exception_type( diff --git a/google/cloud/bigtable/data/_sync_autogen/_read_rows.py b/google/cloud/bigtable/data/_sync_autogen/_read_rows.py index 271aa3fa2..708add7f0 100644 --- a/google/cloud/bigtable/data/_sync_autogen/_read_rows.py +++ b/google/cloud/bigtable/data/_sync_autogen/_read_rows.py @@ -28,7 +28,6 @@ from google.cloud.bigtable.data.exceptions import _RowSetComplete from google.cloud.bigtable.data.exceptions import _ResetRow from google.cloud.bigtable.data._helpers import _attempt_timeout_generator -from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory from google.api_core import retry as retries from google.api_core.retry import exponential_sleep_generator @@ -64,7 +63,6 @@ class _ReadRowsOperation: "request", "table", "_predicate", - "_metadata", "_last_yielded_row_key", "_remaining_count", ) @@ -91,9 +89,6 @@ def __init__( self.request = query._to_pb(table) self.table = table self._predicate = retries.if_exception_type(*retryable_exceptions) - self._metadata = _make_metadata( - table.table_name, table.app_profile_id, instance_name=None - ) self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None @@ -131,10 +126,7 @@ def _read_rows_attempt(self) -> CrossSync._Sync_Impl.Iterable[Row]: if self._remaining_count == 0: return self.merge_rows(None) gapic_stream = self.table.client._gapic_client.read_rows( - self.request, - timeout=next(self.attempt_timeout_gen), - metadata=self._metadata, - retry=None, + self.request, timeout=next(self.attempt_timeout_gen), retry=None ) chunked_stream = self.chunk_stream(gapic_stream) return self.merge_rows(chunked_stream) diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 37e1349bd..3cc8cdbf2 
100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -30,7 +30,6 @@ from google.cloud.bigtable.data.execute_query._parameters_formatting import ( _format_execute_query_params, ) -from google.cloud.bigtable_v2.services.bigtable.client import BigtableClientMeta from google.cloud.bigtable_v2.services.bigtable.transports.base import ( DEFAULT_CLIENT_INFO, ) @@ -52,7 +51,6 @@ from google.cloud.bigtable.data._helpers import TABLE_DEFAULT from google.cloud.bigtable.data._helpers import _WarmedInstanceKey from google.cloud.bigtable.data._helpers import _CONCURRENCY_LIMIT -from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _retry_exception_factory from google.cloud.bigtable.data._helpers import _validate_timeouts from google.cloud.bigtable.data._helpers import _get_error_type @@ -67,8 +65,9 @@ from google.cloud.bigtable.data.row_filters import RowFilterChain from google.cloud.bigtable.data._cross_sync import CrossSync from typing import Iterable -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledBigtableGrpcTransport as PooledTransportType, +from grpc import insecure_channel +from google.cloud.bigtable_v2.services.bigtable.transports import ( + BigtableGrpcTransport as TransportType, ) from google.cloud.bigtable.data._sync_autogen.mutations_batcher import ( MutationsBatcher, @@ -89,11 +88,11 @@ def __init__( self, *, project: str | None = None, - pool_size: int = 3, credentials: google.auth.credentials.Credentials | None = None, client_options: dict[str, Any] | "google.api_core.client_options.ClientOptions" | None = None, + **kwargs, ): """Create a client instance for the Bigtable Data API @@ -103,8 +102,6 @@ def __init__( project: the project which the client acts on behalf of. If not passed, falls back to the default inferred from the environment. - pool_size: The number of grpc channels to maintain - in the internal channel pool. credentials: Thehe OAuth2 Credentials to use for this client. If not passed (and if no ``_http`` object is @@ -114,11 +111,9 @@ def __init__( Client options used to set user options on the client. API Endpoint should be set through client_options. 
Raises: - ValueError: if pool_size is less than 1 """ - transport_str = f"bt-{self._client_version()}-{pool_size}" - transport = PooledTransportType.with_fixed_size(pool_size) - BigtableClientMeta._transport_registry[transport_str] = transport + if "pool_size" in kwargs: + warnings.warn("pool_size no longer supported") client_info = DEFAULT_CLIENT_INFO client_info.client_library_version = self._client_version() if type(client_options) is dict: @@ -126,8 +121,15 @@ def __init__( client_options = cast( Optional[client_options_lib.ClientOptions], client_options ) + custom_channel = None self._emulator_host = os.getenv(BIGTABLE_EMULATOR) if self._emulator_host is not None: + warnings.warn( + "Connecting to Bigtable emulator at {}".format(self._emulator_host), + RuntimeWarning, + stacklevel=2, + ) + custom_channel = insecure_channel(self._emulator_host) if credentials is None: credentials = google.auth.credentials.AnonymousCredentials() if project is None: @@ -139,34 +141,25 @@ def __init__( client_options=client_options, ) self._gapic_client = CrossSync._Sync_Impl.GapicClient( - transport=transport_str, credentials=credentials, client_options=client_options, client_info=client_info, + transport=lambda *args, **kwargs: TransportType( + *args, **kwargs, channel=custom_channel + ), ) self._is_closed = CrossSync._Sync_Impl.Event() - self.transport = cast(PooledTransportType, self._gapic_client.transport) + self.transport = cast(TransportType, self._gapic_client.transport) self._active_instances: Set[_WarmedInstanceKey] = set() self._instance_owners: dict[_WarmedInstanceKey, Set[int]] = {} self._channel_init_time = time.monotonic() - self._channel_refresh_tasks: list[CrossSync._Sync_Impl.Task[None]] = [] + self._channel_refresh_task: CrossSync._Sync_Impl.Task[None] | None = None self._executor = ( concurrent.futures.ThreadPoolExecutor() if not CrossSync._Sync_Impl.is_async else None ) - if self._emulator_host is not None: - warnings.warn( - "Connecting to Bigtable emulator at {}".format(self._emulator_host), - RuntimeWarning, - stacklevel=2, - ) - self.transport._grpc_channel = CrossSync._Sync_Impl.PooledChannel( - pool_size=pool_size, host=self._emulator_host, insecure=True - ) - self.transport._stubs = {} - self.transport._prep_wrapped_messages(client_info) - else: + if self._emulator_host is None: try: self._start_background_channel_refresh() except RuntimeError: @@ -183,49 +176,49 @@ def _client_version() -> str: return version_str def _start_background_channel_refresh(self) -> None: - """Starts a background task to ping and warm each channel in the pool + """Starts a background task to ping and warm grpc channel Raises: None""" if ( - not self._channel_refresh_tasks + not self._channel_refresh_task and (not self._emulator_host) and (not self._is_closed.is_set()) ): CrossSync._Sync_Impl.verify_async_event_loop() - for channel_idx in range(self.transport.pool_size): - refresh_task = CrossSync._Sync_Impl.create_task( - self._manage_channel, - channel_idx, - sync_executor=self._executor, - task_name=f"{self.__class__.__name__} channel refresh {channel_idx}", - ) - self._channel_refresh_tasks.append(refresh_task) + self._channel_refresh_task = CrossSync._Sync_Impl.create_task( + self._manage_channel, + sync_executor=self._executor, + task_name=f"{self.__class__.__name__} channel refresh", + ) def close(self, timeout: float | None = 2.0): """Cancel all background tasks""" self._is_closed.set() - for task in self._channel_refresh_tasks: - task.cancel() + if self._channel_refresh_task is not None: + 
self._channel_refresh_task.cancel() + CrossSync._Sync_Impl.wait([self._channel_refresh_task], timeout=timeout) self.transport.close() if self._executor: self._executor.shutdown(wait=False) - CrossSync._Sync_Impl.wait(self._channel_refresh_tasks, timeout=timeout) - self._channel_refresh_tasks = [] + self._channel_refresh_task = None def _ping_and_warm_instances( - self, channel: Channel, instance_key: _WarmedInstanceKey | None = None + self, + instance_key: _WarmedInstanceKey | None = None, + channel: Channel | None = None, ) -> list[BaseException | None]: """Prepares the backend for requests on a channel Pings each Bigtable instance registered in `_active_instances` on the client Args: - channel: grpc channel to warm instance_key: if provided, only warm the instance associated with the key + channel: grpc channel to warm. If none, warms `self.transport.grpc_channel` Returns: list[BaseException | None]: sequence of results or exceptions from the ping requests """ + channel = channel or self.transport.grpc_channel instance_list = ( [instance_key] if instance_key is not None else self._active_instances ) @@ -254,7 +247,6 @@ def _ping_and_warm_instances( def _manage_channel( self, - channel_idx: int, refresh_interval_min: float = 60 * 35, refresh_interval_max: float = 60 * 45, grace_period: float = 60 * 10, @@ -267,7 +259,6 @@ def _manage_channel( Runs continuously until the client is closed Args: - channel_idx: index of the channel in the transport's channel pool refresh_interval_min: minimum interval before initiating refresh process in seconds. Actual interval will be a random value between `refresh_interval_min` and `refresh_interval_max` @@ -281,32 +272,27 @@ def _manage_channel( ) next_sleep = max(first_refresh - time.monotonic(), 0) if next_sleep > 0: - channel = self.transport.channels[channel_idx] - self._ping_and_warm_instances(channel) + self._ping_and_warm_instances(channel=self.transport.grpc_channel) while not self._is_closed.is_set(): CrossSync._Sync_Impl.event_wait( self._is_closed, next_sleep, async_break_early=False ) if self._is_closed.is_set(): break - new_channel = self.transport.grpc_channel._create_channel() - self._ping_and_warm_instances(new_channel) start_timestamp = time.monotonic() - self.transport.replace_channel( - channel_idx, - grace=grace_period, - new_channel=new_channel, - event=self._is_closed, - ) + old_channel = self.transport.grpc_channel + new_channel = self.transport.create_channel() + self._ping_and_warm_instances(channel=new_channel) + self.transport._grpc_channel = new_channel + old_channel.close(grace_period) next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) - next_sleep = next_refresh - (time.monotonic() - start_timestamp) + next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) def _register_instance( self, instance_id: str, owner: Table | ExecuteQueryIterator ) -> None: - """Registers an instance with the client, and warms the channel pool - for the instance - The client will periodically refresh grpc channel pool used to make + """Registers an instance with the client, and warms the channel for the instance + The client will periodically refresh grpc channel used to make requests, and new channels will be warmed for each registered instance Channels will not be refreshed unless at least one instance is registered @@ -320,11 +306,10 @@ def _register_instance( instance_name, owner.table_name, owner.app_profile_id ) self._instance_owners.setdefault(instance_key, set()).add(id(owner)) - if instance_name not 
in self._active_instances: + if instance_key not in self._active_instances: self._active_instances.add(instance_key) - if self._channel_refresh_tasks: - for channel in self.transport.channels: - self._ping_and_warm_instances(channel, instance_key) + if self._channel_refresh_task: + self._ping_and_warm_instances(instance_key) else: self._start_background_channel_refresh() @@ -467,12 +452,6 @@ def execute_query( "params": pb_params, "proto_format": {}, } - app_profile_id_for_metadata = app_profile_id or "" - req_metadata = _make_metadata( - table_name=None, - app_profile_id=app_profile_id_for_metadata, - instance_name=instance_name, - ) return ExecuteQueryIterator( self, instance_id, @@ -480,8 +459,7 @@ def execute_query( request_body, attempt_timeout, operation_timeout, - req_metadata, - retryable_excs, + retryable_excs=retryable_excs, ) def __enter__(self): @@ -937,16 +915,12 @@ def sample_row_keys( retryable_excs = _get_retryable_errors(retryable_errors, self) predicate = retries.if_exception_type(*retryable_excs) sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - metadata = _make_metadata( - self.table_name, self.app_profile_id, instance_name=None - ) def execute_rpc(): results = self.client._gapic_client.sample_row_keys( table_name=self.table_name, app_profile_id=self.app_profile_id, timeout=next(attempt_timeout_gen), - metadata=metadata, retry=None, ) return [(s.row_key, s.offset_bytes) for s in results] @@ -1066,9 +1040,6 @@ def mutate_row( table_name=self.table_name, app_profile_id=self.app_profile_id, timeout=attempt_timeout, - metadata=_make_metadata( - self.table_name, self.app_profile_id, instance_name=None - ), retry=None, ) return CrossSync._Sync_Impl.retry_target( @@ -1181,9 +1152,6 @@ def check_and_mutate_row( ): false_case_mutations = [false_case_mutations] false_case_list = [m._to_pb() for m in false_case_mutations or []] - metadata = _make_metadata( - self.table_name, self.app_profile_id, instance_name=None - ) result = self.client._gapic_client.check_and_mutate_row( true_mutations=true_case_list, false_mutations=false_case_list, @@ -1191,7 +1159,6 @@ def check_and_mutate_row( row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, table_name=self.table_name, app_profile_id=self.app_profile_id, - metadata=metadata, timeout=operation_timeout, retry=None, ) @@ -1232,15 +1199,11 @@ def read_modify_write_row( rules = [rules] if not rules: raise ValueError("rules must contain at least one item") - metadata = _make_metadata( - self.table_name, self.app_profile_id, instance_name=None - ) result = self.client._gapic_client.read_modify_write_row( rules=[rule._to_pb() for rule in rules], row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, table_name=self.table_name, app_profile_id=self.app_profile_id, - metadata=metadata, timeout=operation_timeout, retry=None, ) diff --git a/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py index aa560b91a..412ef2527 100644 --- a/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py @@ -16,7 +16,7 @@ # This file is automatically generated by CrossSync. Do not edit manually. 
from __future__ import annotations -from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING +from typing import Any, Dict, Optional, Sequence, Tuple, TYPE_CHECKING from google.api_core import retry as retries from google.cloud.bigtable.data.execute_query._byte_cursor import _ByteCursor from google.cloud.bigtable.data._helpers import ( @@ -48,8 +48,8 @@ def __init__( request_body: Dict[str, Any], attempt_timeout: float | None, operation_timeout: float, - req_metadata: Sequence[Tuple[str, str]], - retryable_excs: List[type[Exception]], + req_metadata: Sequence[Tuple[str, str]] = (), + retryable_excs: Sequence[type[Exception]] = (), ) -> None: """Collects responses from ExecuteQuery requests and parses them into QueryResultRows. diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index 40ddc1dcf..859ed89c1 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -158,11 +158,7 @@ def test_ping_and_warm_gapic(self, client, table): ) def test_ping_and_warm(self, client, table): """Test ping and warm from handwritten client""" - try: - channel = client.transport._grpc_channel.pool[0] - except Exception: - channel = client.transport._grpc_channel - results = client._ping_and_warm_instances(channel) + results = client._ping_and_warm_instances() assert len(results) == 1 assert results[0] is None diff --git a/tests/unit/data/_sync_autogen/test__mutate_rows.py b/tests/unit/data/_sync_autogen/test__mutate_rows.py index b86bdb943..dddeec5a6 100644 --- a/tests/unit/data/_sync_autogen/test__mutate_rows.py +++ b/tests/unit/data/_sync_autogen/test__mutate_rows.py @@ -93,15 +93,10 @@ def test_ctor(self): instance._gapic_fn() assert client.mutate_rows.call_count == 1 inner_kwargs = client.mutate_rows.call_args[1] - assert len(inner_kwargs) == 4 + assert len(inner_kwargs) == 3 assert inner_kwargs["table_name"] == table.table_name assert inner_kwargs["app_profile_id"] == table.app_profile_id assert inner_kwargs["retry"] is None - metadata = inner_kwargs["metadata"] - assert len(metadata) == 1 - assert metadata[0][0] == "x-goog-request-params" - assert str(table.table_name) in metadata[0][1] - assert str(table.app_profile_id) in metadata[0][1] entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries] assert instance.mutations == entries_w_pb assert next(instance.timeout_generator) == attempt_timeout diff --git a/tests/unit/data/_sync_autogen/test__read_rows.py b/tests/unit/data/_sync_autogen/test__read_rows.py index 671102ce5..25c209d6e 100644 --- a/tests/unit/data/_sync_autogen/test__read_rows.py +++ b/tests/unit/data/_sync_autogen/test__read_rows.py @@ -72,12 +72,6 @@ def test_ctor(self): assert instance._remaining_count == row_limit assert instance.operation_timeout == expected_operation_timeout assert client.read_rows.call_count == 0 - assert instance._metadata == [ - ( - "x-goog-request-params", - "table_name=test_table&app_profile_id=test_profile", - ) - ] assert instance.request.table_name == table.table_name assert instance.request.app_profile_id == table.app_profile_id assert instance.request.rows_limit == row_limit diff --git a/tests/unit/data/_sync_autogen/test_client.py b/tests/unit/data/_sync_autogen/test_client.py index d31a448ab..7d685fdf0 100644 --- a/tests/unit/data/_sync_autogen/test_client.py +++ b/tests/unit/data/_sync_autogen/test_client.py @@ -61,34 +61,26 @@ def _make_client(cls, *args, use_emulator=True, **kwargs): def test_ctor(self): expected_project = "project-id" - 
expected_pool_size = 11 expected_credentials = AnonymousCredentials() client = self._make_client( - project="project-id", - pool_size=expected_pool_size, - credentials=expected_credentials, - use_emulator=False, + project="project-id", credentials=expected_credentials, use_emulator=False ) CrossSync._Sync_Impl.yield_to_event_loop() assert client.project == expected_project - assert len(client.transport._grpc_channel._pool) == expected_pool_size assert not client._active_instances - assert len(client._channel_refresh_tasks) == expected_pool_size + assert client._channel_refresh_task is not None assert client.transport._credentials == expected_credentials client.close() def test_ctor_super_inits(self): from google.cloud.client import ClientWithProject from google.api_core import client_options as client_options_lib - from google.cloud.bigtable import __version__ as bigtable_version project = "project-id" - pool_size = 11 credentials = AnonymousCredentials() client_options = {"api_endpoint": "foo.bar:1234"} options_parsed = client_options_lib.from_dict(client_options) asyncio_portion = "-async" if CrossSync._Sync_Impl.is_async else "" - transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}" with mock.patch.object( CrossSync._Sync_Impl.GapicClient, "__init__" ) as bigtable_client_init: @@ -100,7 +92,6 @@ def test_ctor_super_inits(self): try: self._make_client( project=project, - pool_size=pool_size, credentials=credentials, client_options=options_parsed, use_emulator=False, @@ -109,7 +100,6 @@ def test_ctor_super_inits(self): pass assert bigtable_client_init.call_count == 1 kwargs = bigtable_client_init.call_args[1] - assert kwargs["transport"] == transport_str assert kwargs["credentials"] == credentials assert kwargs["client_options"] == options_parsed assert client_project_init.call_count == 1 @@ -166,100 +156,23 @@ def test_veneer_grpc_headers(self): ), f"'{wrapped_user_agent_sorted}' does not match {VENEER_HEADER_REGEX}" client.close() - def test_channel_pool_creation(self): - pool_size = 14 - with mock.patch.object( - CrossSync._Sync_Impl.grpc_helpers, - "create_channel", - CrossSync._Sync_Impl.Mock(), - ) as create_channel: - client = self._make_client(project="project-id", pool_size=pool_size) - assert create_channel.call_count == pool_size - client.close() - client = self._make_client(project="project-id", pool_size=pool_size) - pool_list = list(client.transport._grpc_channel._pool) - pool_set = set(client.transport._grpc_channel._pool) - assert len(pool_list) == len(pool_set) - client.close() - - def test_channel_pool_rotation(self): - pool_size = 7 - with mock.patch.object( - CrossSync._Sync_Impl.PooledChannel, "next_channel" - ) as next_channel: - client = self._make_client(project="project-id", pool_size=pool_size) - assert len(client.transport._grpc_channel._pool) == pool_size - next_channel.reset_mock() - with mock.patch.object( - type(client.transport._grpc_channel._pool[0]), "unary_unary" - ) as unary_unary: - channel_next = None - for i in range(pool_size): - channel_last = channel_next - channel_next = client.transport.grpc_channel._pool[i] - assert channel_last != channel_next - next_channel.return_value = channel_next - client.transport.ping_and_warm() - assert next_channel.call_count == i + 1 - unary_unary.assert_called_once() - unary_unary.reset_mock() - client.close() - - def test_channel_pool_replace(self): - import time - - sleep_module = asyncio if CrossSync._Sync_Impl.is_async else time - with mock.patch.object(sleep_module, "sleep"): - pool_size 
= 7 - client = self._make_client(project="project-id", pool_size=pool_size) - for replace_idx in range(pool_size): - start_pool = [ - channel for channel in client.transport._grpc_channel._pool - ] - grace_period = 9 - with mock.patch.object( - type(client.transport._grpc_channel._pool[-1]), "close" - ) as close: - new_channel = client.transport.create_channel() - client.transport.replace_channel( - replace_idx, grace=grace_period, new_channel=new_channel - ) - close.assert_called_once() - assert client.transport._grpc_channel._pool[replace_idx] == new_channel - for i in range(pool_size): - if i != replace_idx: - assert client.transport._grpc_channel._pool[i] == start_pool[i] - else: - assert client.transport._grpc_channel._pool[i] != start_pool[i] - client.close() - - def test__start_background_channel_refresh_tasks_exist(self): + def test__start_background_channel_refresh_task_exists(self): client = self._make_client(project="project-id", use_emulator=False) - assert len(client._channel_refresh_tasks) > 0 + assert client._channel_refresh_task is not None with mock.patch.object(asyncio, "create_task") as create_task: client._start_background_channel_refresh() create_task.assert_not_called() client.close() - @pytest.mark.parametrize("pool_size", [1, 3, 7]) - def test__start_background_channel_refresh(self, pool_size): - import concurrent.futures - - with mock.patch.object( - self._get_target_class(), - "_ping_and_warm_instances", - CrossSync._Sync_Impl.Mock(), - ) as ping_and_warm: - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - client._start_background_channel_refresh() - assert len(client._channel_refresh_tasks) == pool_size - for task in client._channel_refresh_tasks: - assert isinstance(task, concurrent.futures.Future) - assert ping_and_warm.call_count == pool_size - for channel in client.transport._grpc_channel._pool: - ping_and_warm.assert_any_call(channel) + def test__start_background_channel_refresh(self): + client = self._make_client(project="project-id", use_emulator=False) + ping_and_warm = CrossSync._Sync_Impl.Mock() + client._ping_and_warm_instances = ping_and_warm + client._start_background_channel_refresh() + assert client._channel_refresh_task is not None + assert isinstance(client._channel_refresh_task, asyncio.Task) + asyncio.sleep(0.1) + assert ping_and_warm.call_count == 1 client.close() def test__ping_and_warm_instances(self): @@ -277,7 +190,7 @@ def test__ping_and_warm_instances(self): channel = mock.Mock() client_mock._active_instances = [] result = self._get_target_class()._ping_and_warm_instances( - client_mock, channel + client_mock, channel=channel ) assert len(result) == 0 assert gather.call_args[1]["return_exceptions"] is True @@ -288,7 +201,7 @@ def test__ping_and_warm_instances(self): gather.reset_mock() channel.reset_mock() result = self._get_target_class()._ping_and_warm_instances( - client_mock, channel + client_mock, channel=channel ) assert len(result) == 4 gather.assert_called_once() @@ -328,10 +241,12 @@ def test__ping_and_warm_single_instance(self): client_mock._active_instances = [mock.Mock()] * 100 test_key = ("test-instance", "test-table", "test-app-profile") result = self._get_target_class()._ping_and_warm_instances( - client_mock, channel, test_key + client_mock, test_key ) assert len(result) == 1 - grpc_call_args = channel.unary_unary().call_args_list + grpc_call_args = ( + client_mock.transport.grpc_channel.unary_unary().call_args_list + ) assert len(grpc_call_args) == 1 kwargs = 
grpc_call_args[0][1] request = kwargs["request"] @@ -360,7 +275,7 @@ def test__manage_channel_first_sleep( try: client = self._make_client(project="project-id") client._channel_init_time = -wait_time - client._manage_channel(0, refresh_interval, refresh_interval) + client._manage_channel(refresh_interval, refresh_interval) except asyncio.CancelledError: pass sleep.assert_called_once() @@ -378,38 +293,26 @@ def test__manage_channel_ping_and_warm(self): client_mock = mock.Mock() client_mock._is_closed.is_set.return_value = False client_mock._channel_init_time = time.monotonic() - channel_list = [mock.Mock(), mock.Mock()] - client_mock.transport.channels = channel_list - new_channel = mock.Mock() - client_mock.transport.grpc_channel._create_channel.return_value = new_channel + orig_channel = client_mock.transport.grpc_channel sleep_tuple = ( (asyncio, "sleep") if CrossSync._Sync_Impl.is_async else (threading.Event, "wait") ) with mock.patch.object(*sleep_tuple): - client_mock.transport.replace_channel.side_effect = asyncio.CancelledError + orig_channel.close.side_effect = asyncio.CancelledError ping_and_warm = ( client_mock._ping_and_warm_instances ) = CrossSync._Sync_Impl.Mock() try: - channel_idx = 1 - self._get_target_class()._manage_channel(client_mock, channel_idx, 10) + self._get_target_class()._manage_channel(client_mock, 10) except asyncio.CancelledError: pass assert ping_and_warm.call_count == 2 - assert client_mock.transport.replace_channel.call_count == 1 - old_channel = channel_list[channel_idx] - assert old_channel != new_channel - called_with = [call[0][0] for call in ping_and_warm.call_args_list] - assert old_channel in called_with - assert new_channel in called_with - ping_and_warm.reset_mock() - try: - self._get_target_class()._manage_channel(client_mock, 0, 0, 0) - except asyncio.CancelledError: - pass - ping_and_warm.assert_called_once_with(new_channel) + assert client_mock.transport._grpc_channel != orig_channel + called_with = [call[1]["channel"] for call in ping_and_warm.call_args_list] + assert orig_channel in called_with + assert client_mock.transport.grpc_channel in called_with @pytest.mark.parametrize( "refresh_interval, num_cycles, expected_sleep", @@ -420,7 +323,8 @@ def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sle import random import threading - channel_idx = 1 + channel = mock.Mock() + channel.close = mock.AsyncMock() with mock.patch.object(random, "uniform") as uniform: uniform.side_effect = lambda min_, max_: min_ with mock.patch.object(time, "time") as time_mock: @@ -435,14 +339,17 @@ def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sle asyncio.CancelledError ] client = self._make_client(project="project-id") - with mock.patch.object(client.transport, "replace_channel"): + client.transport._grpc_channel = channel + with mock.patch.object( + client.transport, "replace_channel", return_value=channel + ): try: if refresh_interval is not None: client._manage_channel( - channel_idx, refresh_interval, refresh_interval + refresh_interval, refresh_interval ) else: - client._manage_channel(channel_idx) + client._manage_channel() except asyncio.CancelledError: pass assert sleep.call_count == num_cycles @@ -468,76 +375,53 @@ def test__manage_channel_random(self): uniform.return_value = 0 try: uniform.side_effect = asyncio.CancelledError - client = self._make_client(project="project-id", pool_size=1) + client = self._make_client(project="project-id") except asyncio.CancelledError: uniform.side_effect = None 
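# Hedged sketch of the jittered refresh interval these tests exercise by mocking
# random.uniform: the client (per the _manage_channel change later in this series)
# picks a refresh target between the min/max bounds and subtracts the time already
# spent. Names below are illustrative, not the library implementation verbatim.
import random
import time

def _next_sleep(refresh_interval_min: float, refresh_interval_max: float, start_timestamp: float) -> float:
    next_refresh = random.uniform(refresh_interval_min, refresh_interval_max)
    # never sleep a negative amount if the refresh work itself overran the interval
    return max(next_refresh - (time.monotonic() - start_timestamp), 0)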
uniform.reset_mock() sleep.reset_mock() - min_val = 200 - max_val = 205 - uniform.side_effect = lambda min_, max_: min_ - sleep.side_effect = [None, None, asyncio.CancelledError] - try: - with mock.patch.object(client.transport, "replace_channel"): - client._manage_channel(0, min_val, max_val) - except asyncio.CancelledError: - pass - assert uniform.call_count == 3 - uniform_args = [call[0] for call in uniform.call_args_list] - for found_min, found_max in uniform_args: - assert found_min == min_val - assert found_max == max_val + with mock.patch.object(client.transport, "create_channel"): + min_val = 200 + max_val = 205 + uniform.side_effect = lambda min_, max_: min_ + sleep.side_effect = [None, asyncio.CancelledError] + try: + client._manage_channel(min_val, max_val) + except asyncio.CancelledError: + pass + assert uniform.call_count == 2 + uniform_args = [call[0] for call in uniform.call_args_list] + for found_min, found_max in uniform_args: + assert found_min == min_val + assert found_max == max_val @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) def test__manage_channel_refresh(self, num_cycles): - import threading - expected_grace = 9 expected_refresh = 0.5 - channel_idx = 1 grpc_lib = grpc.aio if CrossSync._Sync_Impl.is_async else grpc new_channel = grpc_lib.insecure_channel("localhost:8080") - with mock.patch.object( - CrossSync._Sync_Impl.PooledTransport, "replace_channel" - ) as replace_channel: - sleep_tuple = ( - (asyncio, "sleep") - if CrossSync._Sync_Impl.is_async - else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: - sleep.side_effect = [None for i in range(num_cycles)] + [ - asyncio.CancelledError - ] - with mock.patch.object( - CrossSync._Sync_Impl.grpc_helpers, "create_channel" - ) as create_channel: - create_channel.return_value = new_channel - with mock.patch.object( - self._get_target_class(), "_start_background_channel_refresh" - ): - client = self._make_client( - project="project-id", use_emulator=False - ) - create_channel.reset_mock() - try: - client._manage_channel( - channel_idx, - refresh_interval_min=expected_refresh, - refresh_interval_max=expected_refresh, - grace_period=expected_grace, - ) - except asyncio.CancelledError: - pass - assert sleep.call_count == num_cycles + 1 - assert create_channel.call_count == num_cycles - assert replace_channel.call_count == num_cycles - for call in replace_channel.call_args_list: - (args, kwargs) = call - assert args[0] == channel_idx - assert kwargs["grace"] == expected_grace - assert kwargs["new_channel"] == new_channel - client.close() + with mock.patch.object(CrossSync._Sync_Impl, "event_wait") as sleep: + sleep.side_effect = [None for i in range(num_cycles)] + [ + asyncio.CancelledError + ] + with mock.patch.object( + CrossSync._Sync_Impl.grpc_helpers, "create_channel" + ) as create_channel: + create_channel.return_value = new_channel + client = self._make_client(project="project-id", use_emulator=False) + create_channel.reset_mock() + try: + client._manage_channel( + refresh_interval_min=expected_refresh, + refresh_interval_max=expected_refresh, + grace_period=expected_grace, + ) + except asyncio.CancelledError: + pass + assert sleep.call_count == num_cycles + 1 + assert create_channel.call_count == num_cycles + client.close() def test__register_instance(self): """test instance registration""" @@ -547,12 +431,7 @@ def test__register_instance(self): instance_owners = {} client_mock._active_instances = active_instances client_mock._instance_owners = instance_owners - 
client_mock._channel_refresh_tasks = [] - client_mock._start_background_channel_refresh.side_effect = ( - lambda: client_mock._channel_refresh_tasks.append(mock.Mock) - ) - mock_channels = [mock.Mock() for i in range(5)] - client_mock.transport.channels = mock_channels + client_mock._channel_refresh_task = None client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() table_mock = mock.Mock() self._get_target_class()._register_instance( @@ -568,18 +447,17 @@ def test__register_instance(self): assert expected_key == tuple(list(active_instances)[0]) assert len(instance_owners) == 1 assert expected_key == tuple(list(instance_owners)[0]) - assert client_mock._channel_refresh_tasks + client_mock._channel_refresh_task = mock.Mock() table_mock2 = mock.Mock() self._get_target_class()._register_instance( client_mock, "instance-2", table_mock2 ) assert client_mock._start_background_channel_refresh.call_count == 1 - assert client_mock._ping_and_warm_instances.call_count == len(mock_channels) - for channel in mock_channels: - assert channel in [ - call[0][0] - for call in client_mock._ping_and_warm_instances.call_args_list - ] + assert ( + client_mock._ping_and_warm_instances.call_args[0][0][0] + == "prefix/instance-2" + ) + assert client_mock._ping_and_warm_instances.call_count == 1 assert len(active_instances) == 2 assert len(instance_owners) == 2 expected_key2 = ( @@ -600,6 +478,41 @@ def test__register_instance(self): ] ) + def test__register_instance_duplicate(self): + """test double instance registration. Should be no-op""" + client_mock = mock.Mock() + client_mock._gapic_client.instance_path.side_effect = lambda a, b: f"prefix/{b}" + active_instances = set() + instance_owners = {} + client_mock._active_instances = active_instances + client_mock._instance_owners = instance_owners + client_mock._channel_refresh_task = object() + mock_channels = [mock.Mock()] + client_mock.transport.channels = mock_channels + client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() + table_mock = mock.Mock() + expected_key = ( + "prefix/instance-1", + table_mock.table_name, + table_mock.app_profile_id, + ) + self._get_target_class()._register_instance( + client_mock, "instance-1", table_mock + ) + assert len(active_instances) == 1 + assert expected_key == tuple(list(active_instances)[0]) + assert len(instance_owners) == 1 + assert expected_key == tuple(list(instance_owners)[0]) + assert client_mock._ping_and_warm_instances.call_count == 1 + self._get_target_class()._register_instance( + client_mock, "instance-1", table_mock + ) + assert len(active_instances) == 1 + assert expected_key == tuple(list(active_instances)[0]) + assert len(instance_owners) == 1 + assert expected_key == tuple(list(instance_owners)[0]) + assert client_mock._ping_and_warm_instances.call_count == 1 + @pytest.mark.parametrize( "insert_instances,expected_active,expected_owner_keys", [ @@ -623,12 +536,7 @@ def test__register_instance_state( instance_owners = {} client_mock._active_instances = active_instances client_mock._instance_owners = instance_owners - client_mock._channel_refresh_tasks = [] - client_mock._start_background_channel_refresh.side_effect = ( - lambda: client_mock._channel_refresh_tasks.append(mock.Mock) - ) - mock_channels = [mock.Mock() for i in range(5)] - client_mock.transport.channels = mock_channels + client_mock._channel_refresh_task = None client_mock._ping_and_warm_instances = CrossSync._Sync_Impl.Mock() table_mock = mock.Mock() for instance, table, profile in insert_instances: @@ -866,60 +774,39 @@ 
def test_get_table_context_manager(self): assert client._instance_owners[instance_key] == {id(table)} assert close_mock.call_count == 1 - def test_multiple_pool_sizes(self): - pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256] - for pool_size in pool_sizes: - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - assert len(client._channel_refresh_tasks) == pool_size - client_duplicate = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - assert len(client_duplicate._channel_refresh_tasks) == pool_size - assert str(pool_size) in str(client.transport) - client.close() - client_duplicate.close() - def test_close(self): - pool_size = 7 - client = self._make_client( - project="project-id", pool_size=pool_size, use_emulator=False - ) - assert len(client._channel_refresh_tasks) == pool_size - tasks_list = list(client._channel_refresh_tasks) - for task in client._channel_refresh_tasks: - assert not task.done() + client = self._make_client(project="project-id", use_emulator=False) + task = client._channel_refresh_task + assert task is not None + assert not task.done() with mock.patch.object( - CrossSync._Sync_Impl.PooledTransport, "close", CrossSync._Sync_Impl.Mock() + client.transport, "close", CrossSync._Sync_Impl.Mock() ) as close_mock: client.close() close_mock.assert_called_once() - for task in tasks_list: - assert task.done() + close_mock.assert_awaited() + assert task.done() + assert task.cancelled() + assert client._channel_refresh_task is None def test_close_with_timeout(self): - pool_size = 7 expected_timeout = 19 - client = self._make_client(project="project-id", pool_size=pool_size) - tasks = list(client._channel_refresh_tasks) + client = self._make_client(project="project-id", use_emulator=False) with mock.patch.object( CrossSync._Sync_Impl, "wait", CrossSync._Sync_Impl.Mock() ) as wait_for_mock: client.close(timeout=expected_timeout) wait_for_mock.assert_called_once() assert wait_for_mock.call_args[1]["timeout"] == expected_timeout - client._channel_refresh_tasks = tasks client.close() def test_context_manager(self): close_mock = CrossSync._Sync_Impl.Mock() true_close = None - with self._make_client(project="project-id") as client: + with self._make_client(project="project-id", use_emulator=False) as client: true_close = client.close() client.close = close_mock - for task in client._channel_refresh_tasks: - assert not task.done() + assert not client._channel_refresh_task.done() assert client.project == "project-id" assert client._active_instances == set() close_mock.assert_not_called() @@ -1142,34 +1029,31 @@ def test_customizable_retryable_errors( ) @pytest.mark.parametrize("include_app_profile", [True, False]) def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): - """check that all requests attach proper metadata headers""" profile = "profile" if include_app_profile else None - with mock.patch.object( - CrossSync._Sync_Impl.GapicClient, gapic_fn, CrossSync._Sync_Impl.Mock() - ) as gapic_mock: - gapic_mock.side_effect = RuntimeError("stop early") - with self._make_client() as client: - table = self._get_target_class()( - client, "instance-id", "table-id", profile - ) - try: - test_fn = table.__getattribute__(fn_name) - maybe_stream = test_fn(*fn_args) - [i for i in maybe_stream] - except Exception: - pass - kwargs = gapic_mock.call_args_list[0][1] - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - 
goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata + client = self._make_client() + transport_mock = mock.MagicMock() + rpc_mock = mock.AsyncMock() + transport_mock._wrapped_methods.__getitem__.return_value = rpc_mock + client._gapic_client._client._transport = transport_mock + client._gapic_client._client._is_universe_domain_valid = True + table = self._get_target_class()(client, "instance-id", "table-id", profile) + try: + test_fn = table.__getattribute__(fn_name) + maybe_stream = test_fn(*fn_args) + [i for i in maybe_stream] + except Exception: + pass + assert rpc_mock.call_count == 1 + kwargs = rpc_mock.call_args_list[0].kwargs + metadata = kwargs["metadata"] + assert len(metadata) == 1 + assert metadata[0][0] == "x-goog-request-params" + routing_str = metadata[0][1] + assert "table_name=" + table.table_name in routing_str + if include_app_profile: + assert "app_profile_id=profile" in routing_str + else: + assert "app_profile_id=" not in routing_str @CrossSync._Sync_Impl.add_mapping_decorator("TestReadRows") @@ -1902,11 +1786,10 @@ def test_sample_row_keys_gapic_params(self): table.sample_row_keys(attempt_timeout=expected_timeout) (args, kwargs) = sample_row_keys.call_args assert len(args) == 0 - assert len(kwargs) == 5 + assert len(kwargs) == 4 assert kwargs["timeout"] == expected_timeout assert kwargs["app_profile_id"] == expected_profile assert kwargs["table_name"] == table.table_name - assert kwargs["metadata"] is not None assert kwargs["retry"] is None @pytest.mark.parametrize( @@ -2078,29 +1961,6 @@ def test_mutate_row_non_retryable_errors(self, non_retryable_exception): assert mutation.is_idempotent() is True table.mutate_row("row_key", mutation, operation_timeout=0.2) - @pytest.mark.parametrize("include_app_profile", [True, False]) - def test_mutate_row_metadata(self, include_app_profile): - """request should attach metadata headers""" - profile = "profile" if include_app_profile else None - with self._make_client() as client: - with client.get_table("i", "t", app_profile_id=profile) as table: - with mock.patch.object( - client._gapic_client, "mutate_row", CrossSync._Sync_Impl.Mock() - ) as read_rows: - table.mutate_row("rk", mock.Mock()) - kwargs = read_rows.call_args_list[0][1] - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata - @pytest.mark.parametrize("mutations", [[], None]) def test_mutate_row_no_mutations(self, mutations): with self._make_client() as client: From 7a1e422a5bcef4db09e8b3875ce9ba1917866a00 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 11 Nov 2024 13:35:18 -0800 Subject: [PATCH 345/360] removed unneeded import --- google/cloud/bigtable/data/__init__.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index 68a315e50..e08993108 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -54,12 +54,6 @@ from 
google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync from google.cloud.bigtable.data._async._mutate_rows import _MutateRowsOperationAsync -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledBigtableGrpcTransport, -) -from google.cloud.bigtable_v2.services.bigtable.transports.pooled_grpc import ( - PooledChannel, -) from google.cloud.bigtable_v2.services.bigtable.client import ( BigtableClient, ) From 008e7245fd09db8e86f3772f28df069228872128 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 11 Nov 2024 13:41:22 -0800 Subject: [PATCH 346/360] fix tests --- tests/unit/data/_async/test_client.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index a8f1ad42b..2e18d83ab 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -422,7 +422,7 @@ async def test__manage_channel_sleeps( import threading channel = mock.Mock() - channel.close = mock.AsyncMock() + channel.close = CrossSync.Mock() with mock.patch.object(random, "uniform") as uniform: uniform.side_effect = lambda min_, max_: min_ with mock.patch.object(time, "time") as time_mock: @@ -438,16 +438,15 @@ async def test__manage_channel_sleeps( ] client = self._make_client(project="project-id") client.transport._grpc_channel = channel - with mock.patch.object(client.transport, "replace_channel", return_value=channel): - try: - if refresh_interval is not None: - await client._manage_channel( - refresh_interval, refresh_interval - ) - else: - await client._manage_channel() - except asyncio.CancelledError: - pass + try: + if refresh_interval is not None: + await client._manage_channel( + refresh_interval, refresh_interval + ) + else: + await client._manage_channel() + except asyncio.CancelledError: + pass assert sleep.call_count == num_cycles if CrossSync.is_async: total_sleep = sum([call[0][0] for call in sleep.call_args_list]) @@ -1274,7 +1273,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ client = self._make_client() # create mock for rpc stub transport_mock = mock.MagicMock() - rpc_mock = mock.AsyncMock() + rpc_mock = CrossSync.Mock() transport_mock._wrapped_methods.__getitem__.return_value = rpc_mock client._gapic_client._client._transport = transport_mock client._gapic_client._client._is_universe_domain_valid = True From 2d04db77278e306f13cd0265ccc1e42094a8e2be Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 11 Nov 2024 15:41:31 -0800 Subject: [PATCH 347/360] fixed formatter param --- tests/unit/data/test_sync_up_to_date.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/data/test_sync_up_to_date.py b/tests/unit/data/test_sync_up_to_date.py index 7f3ef17a1..66d9a1619 100644 --- a/tests/unit/data/test_sync_up_to_date.py +++ b/tests/unit/data/test_sync_up_to_date.py @@ -56,10 +56,10 @@ def test_sync_up_to_date(sync_file): If this test fails, run `nox -s generate_sync` to update the sync files. 
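# Sketch of the up-to-date check performed just below, assuming new_render and
# found_render are the freshly generated and on-disk renders as plain strings.
# Only standard-library pieces are used; this is illustrative, not the test itself.
import hashlib
from difflib import unified_diff

def renders_match(found_render: str, new_render: str) -> bool:
    # content comparison: an empty unified diff means the renders agree line-for-line
    diff_str = "\n".join(unified_diff(found_render.splitlines(), new_render.splitlines(), lineterm=""))
    # hash comparison: a cheap secondary check that the byte content is identical
    same_hash = hashlib.md5(found_render.encode()).hexdigest() == hashlib.md5(new_render.encode()).hexdigest()
    return not diff_str and same_hash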
""" path = sync_file.output_path - new_render = sync_file.render(with_black=True, save_to_disk=False) + new_render = sync_file.render(with_formatter=True, save_to_disk=False) found_render = CrossSyncOutputFile( output_path="", ast_tree=ast.parse(open(path).read()), header=sync_file.header - ).render(with_black=True, save_to_disk=False) + ).render(with_formatter=True, save_to_disk=False) # compare by content diff = unified_diff(found_render.splitlines(), new_render.splitlines(), lineterm="") diff_str = "\n".join(diff) From 7fb2134bc93be3c572bf237d95db655a70e07010 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 11 Nov 2024 15:42:04 -0800 Subject: [PATCH 348/360] fixed grace_period for sync client --- google/cloud/bigtable/data/_async/client.py | 10 ++- tests/unit/data/_async/test_client.py | 70 +++++++++------------ 2 files changed, 39 insertions(+), 41 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 689202267..a81178ea3 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -365,8 +365,14 @@ async def _manage_channel( await self._ping_and_warm_instances(channel=new_channel) # cycle channel out of use, with long grace window before closure self.transport._grpc_channel = new_channel - await old_channel.close(grace_period) - # subtract the time spent waiting for the channel to be replaced + # give old_channel a chance to complete existing rpcs + if CrossSync.is_async: + await old_channel.close(grace_period) + else: + if grace_period: + self._is_closed.wait(grace_period) + old_channel.close() + # subtract thed time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 2e18d83ab..1208d55a3 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -213,15 +213,15 @@ async def test__start_background_channel_refresh_task_exists(self): @CrossSync.pytest async def test__start_background_channel_refresh(self): # should create background tasks for each channel - client = self._make_client(project="project-id", use_emulator=False) - ping_and_warm = CrossSync.Mock() - client._ping_and_warm_instances = ping_and_warm - client._start_background_channel_refresh() - assert client._channel_refresh_task is not None - assert isinstance(client._channel_refresh_task, asyncio.Task) - await asyncio.sleep(0.1) - assert ping_and_warm.call_count == 1 - await client.close() + client = self._make_client(project="project-id") + with mock.patch.object(client, "_ping_and_warm_instances", CrossSync.Mock()) as ping_and_warm: + client._emulator_host = None + client._start_background_channel_refresh() + assert client._channel_refresh_task is not None + assert isinstance(client._channel_refresh_task, CrossSync.Task) + await CrossSync.sleep(0.1) + assert ping_and_warm.call_count == 1 + await client.close() @CrossSync.drop @CrossSync.pytest @@ -427,12 +427,7 @@ async def test__manage_channel_sleeps( uniform.side_effect = lambda min_, max_: min_ with mock.patch.object(time, "time") as time_mock: time_mock.return_value = 0 - sleep_tuple = ( - (asyncio, "sleep") - if CrossSync.is_async - else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: + with mock.patch.object(CrossSync, "event_wait") as sleep: 
sleep.side_effect = [None for i in range(num_cycles - 1)] + [ asyncio.CancelledError ] @@ -441,19 +436,14 @@ async def test__manage_channel_sleeps( try: if refresh_interval is not None: await client._manage_channel( - refresh_interval, refresh_interval + refresh_interval, refresh_interval, grace_period=0 ) else: - await client._manage_channel() + await client._manage_channel(grace_period=0) except asyncio.CancelledError: pass assert sleep.call_count == num_cycles - if CrossSync.is_async: - total_sleep = sum([call[0][0] for call in sleep.call_args_list]) - else: - total_sleep = sum( - [call[1]["timeout"] for call in sleep.call_args_list] - ) + total_sleep = sum([call[0][1] for call in sleep.call_args_list]) assert ( abs(total_sleep - expected_sleep) < 0.1 ), f"refresh_interval={refresh_interval}, num_cycles={num_cycles}, expected_sleep={expected_sleep}" @@ -464,10 +454,7 @@ async def test__manage_channel_random(self): import random import threading - sleep_tuple = ( - (asyncio, "sleep") if CrossSync.is_async else (threading.Event, "wait") - ) - with mock.patch.object(*sleep_tuple) as sleep: + with mock.patch.object(CrossSync, "event_wait") as sleep: with mock.patch.object(random, "uniform") as uniform: uniform.return_value = 0 try: @@ -483,7 +470,7 @@ async def test__manage_channel_random(self): uniform.side_effect = lambda min_, max_: min_ sleep.side_effect = [None, asyncio.CancelledError] try: - await client._manage_channel(min_val, max_val) + await client._manage_channel(min_val, max_val, grace_period=0) except asyncio.CancelledError: pass assert uniform.call_count == 2 @@ -496,28 +483,27 @@ async def test__manage_channel_random(self): @pytest.mark.parametrize("num_cycles", [0, 1, 10, 100]) async def test__manage_channel_refresh(self, num_cycles): # make sure that channels are properly refreshed - expected_grace = 9 expected_refresh = 0.5 grpc_lib = grpc.aio if CrossSync.is_async else grpc new_channel = grpc_lib.insecure_channel("localhost:8080") with mock.patch.object(CrossSync, "event_wait") as sleep: sleep.side_effect = [None for i in range(num_cycles)] + [ - asyncio.CancelledError + RuntimeError ] with mock.patch.object( CrossSync.grpc_helpers, "create_channel" ) as create_channel: create_channel.return_value = new_channel - client = self._make_client(project="project-id", use_emulator=False) + client = self._make_client(project="project-id") create_channel.reset_mock() try: await client._manage_channel( refresh_interval_min=expected_refresh, refresh_interval_max=expected_refresh, - grace_period=expected_grace, + grace_period=0, ) - except asyncio.CancelledError: + except RuntimeError: pass assert sleep.call_count == num_cycles + 1 assert create_channel.call_count == num_cycles @@ -935,9 +921,9 @@ async def test_close(self): with mock.patch.object(client.transport, "close", CrossSync.Mock()) as close_mock: await client.close() close_mock.assert_called_once() - close_mock.assert_awaited() + if CrossSync.is_async: + close_mock.assert_awaited() assert task.done() - assert task.cancelled() assert client._channel_refresh_task is None @CrossSync.pytest @@ -954,11 +940,13 @@ async def test_close_with_timeout(self): @CrossSync.pytest async def test_context_manager(self): + from functools import partial # context manager should close the client cleanly close_mock = CrossSync.Mock() true_close = None async with self._make_client(project="project-id", use_emulator=False) as client: - true_close = client.close() + # grab reference to close coro for async test + true_close = partial(client.close) 
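# Minimal, self-contained sketch of the functools.partial pattern used above: keep a
# deferred reference to the real close() before monkeypatching it, assert against the
# mock, then invoke the original afterwards. This is illustrative, not the test code.
import asyncio
from functools import partial

class _FakeClient:
    async def close(self):
        print("real close ran")

async def _demo():
    client = _FakeClient()
    real_close = partial(client.close)            # a reference, not a coroutine object yet
    client.close = lambda: print("mock close")    # stand-in for the CrossSync mock
    client.close()                                # assertions would target the mock here
    await real_close()                            # finally run the real close

asyncio.run(_demo())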
client.close = close_mock assert not client._channel_refresh_task.done() assert client.project == "project-id" @@ -968,7 +956,7 @@ async def test_context_manager(self): if CrossSync.is_async: close_mock.assert_awaited() # actually close the client - await true_close + await true_close() @CrossSync.drop def test_client_ctor_sync(self): @@ -1275,8 +1263,12 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ transport_mock = mock.MagicMock() rpc_mock = CrossSync.Mock() transport_mock._wrapped_methods.__getitem__.return_value = rpc_mock - client._gapic_client._client._transport = transport_mock - client._gapic_client._client._is_universe_domain_valid = True + gapic_client = client._gapic_client + if CrossSync.is_async: + # inner BigtableClient is held as ._client for BigtableAsyncClient + gapic_client = gapic_client._client + gapic_client._transport = transport_mock + gapic_client._is_universe_domain_valid = True table = self._get_target_class()(client, "instance-id", "table-id", profile) try: test_fn = table.__getattribute__(fn_name) From e8d122e6dec3ce5d1dad5730ee2447bf9446fb5c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 11 Nov 2024 15:46:30 -0800 Subject: [PATCH 349/360] fixed lint --- google/cloud/bigtable/data/_async/client.py | 4 +++- tests/unit/data/_async/test_client.py | 25 ++++++++++----------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index a81178ea3..b56edcc3b 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -275,7 +275,9 @@ async def close(self, timeout: float | None = 2.0): @CrossSync.convert async def _ping_and_warm_instances( - self, instance_key: _WarmedInstanceKey | None = None, channel: Channel | None = None + self, + instance_key: _WarmedInstanceKey | None = None, + channel: Channel | None = None, ) -> list[BaseException | None]: """ Prepares the backend for requests on a channel diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 1208d55a3..268e77c19 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -99,13 +99,11 @@ async def test_ctor(self): async def test_ctor_super_inits(self): from google.cloud.client import ClientWithProject from google.api_core import client_options as client_options_lib - from google.cloud.bigtable import __version__ as bigtable_version project = "project-id" credentials = AnonymousCredentials() client_options = {"api_endpoint": "foo.bar:1234"} options_parsed = client_options_lib.from_dict(client_options) - asyncio_portion = "-async" if CrossSync.is_async else "" with mock.patch.object( CrossSync.GapicClient, "__init__" ) as bigtable_client_init: @@ -214,7 +212,9 @@ async def test__start_background_channel_refresh_task_exists(self): async def test__start_background_channel_refresh(self): # should create background tasks for each channel client = self._make_client(project="project-id") - with mock.patch.object(client, "_ping_and_warm_instances", CrossSync.Mock()) as ping_and_warm: + with mock.patch.object( + client, "_ping_and_warm_instances", CrossSync.Mock() + ) as ping_and_warm: client._emulator_host = None client._start_background_channel_refresh() assert client._channel_refresh_task is not None @@ -310,7 +310,6 @@ async def test__ping_and_warm_single_instance(self): CrossSync, "gather_partials", CrossSync.Mock() ) as gather: 
gather.side_effect = lambda *args, **kwargs: [fn() for fn in args[0]] - channel = mock.Mock() # test with large set of instances client_mock._active_instances = [mock.Mock()] * 100 test_key = ("test-instance", "test-table", "test-app-profile") @@ -419,7 +418,6 @@ async def test__manage_channel_sleeps( # make sure that sleeps work as expected import time import random - import threading channel = mock.Mock() channel.close = CrossSync.Mock() @@ -452,7 +450,6 @@ async def test__manage_channel_sleeps( @CrossSync.pytest async def test__manage_channel_random(self): import random - import threading with mock.patch.object(CrossSync, "event_wait") as sleep: with mock.patch.object(random, "uniform") as uniform: @@ -488,9 +485,7 @@ async def test__manage_channel_refresh(self, num_cycles): new_channel = grpc_lib.insecure_channel("localhost:8080") with mock.patch.object(CrossSync, "event_wait") as sleep: - sleep.side_effect = [None for i in range(num_cycles)] + [ - RuntimeError - ] + sleep.side_effect = [None for i in range(num_cycles)] + [RuntimeError] with mock.patch.object( CrossSync.grpc_helpers, "create_channel" ) as create_channel: @@ -918,7 +913,9 @@ async def test_close(self): task = client._channel_refresh_task assert task is not None assert not task.done() - with mock.patch.object(client.transport, "close", CrossSync.Mock()) as close_mock: + with mock.patch.object( + client.transport, "close", CrossSync.Mock() + ) as close_mock: await client.close() close_mock.assert_called_once() if CrossSync.is_async: @@ -941,10 +938,13 @@ async def test_close_with_timeout(self): @CrossSync.pytest async def test_context_manager(self): from functools import partial + # context manager should close the client cleanly close_mock = CrossSync.Mock() true_close = None - async with self._make_client(project="project-id", use_emulator=False) as client: + async with self._make_client( + project="project-id", use_emulator=False + ) as client: # grab reference to close coro for async test true_close = partial(client.close) client.close = close_mock @@ -1255,8 +1255,6 @@ async def test_customizable_retryable_errors( @CrossSync.pytest @CrossSync.convert async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): - from google.cloud.bigtable.data import TableAsync - profile = "profile" if include_app_profile else None client = self._make_client() # create mock for rpc stub @@ -1291,6 +1289,7 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ else: assert "app_profile_id=" not in routing_str + @CrossSync.convert_class( "TestReadRows", add_mapping_for_name="TestReadRows", From d489ad3389e0231583cc1d1045ca05ec3e16517d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 11 Nov 2024 16:02:55 -0800 Subject: [PATCH 350/360] cleaned up imports --- google/cloud/bigtable/data/__init__.py | 1 + google/cloud/bigtable/data/_async/client.py | 35 ++++++++----------- .../bigtable/data/_async/mutations_batcher.py | 4 +-- .../_async/execute_query_iterator.py | 2 -- 4 files changed, 17 insertions(+), 25 deletions(-) diff --git a/google/cloud/bigtable/data/__init__.py b/google/cloud/bigtable/data/__init__.py index 23d24a4f7..ec0dd24fb 100644 --- a/google/cloud/bigtable/data/__init__.py +++ b/google/cloud/bigtable/data/__init__.py @@ -61,6 +61,7 @@ CrossSync.add_mapping("_ReadRowsOperation", _ReadRowsOperationAsync) CrossSync.add_mapping("_MutateRowsOperation", _MutateRowsOperationAsync) CrossSync.add_mapping("ExecuteQueryIterator", ExecuteQueryIteratorAsync) 
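# Hedged sketch (an assumption, not the actual CrossSync implementation) of what an
# add_mapping registry like the calls above could provide: registering a class under a
# name so generated sync code can refer to it as an attribute, e.g. Registry.MutationsBatcher.
class _Registry:
    _mapping: dict = {}

    @classmethod
    def add_mapping(cls, name: str, value) -> None:
        cls._mapping[name] = value
        setattr(cls, name, value)  # expose the registered class as an attribute on the registry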
+CrossSync.add_mapping("MutationsBatcher", MutationsBatcherAsync) __version__: str = package_version.__version__ diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index b56edcc3b..d560d7e1e 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -83,30 +83,25 @@ from google.cloud.bigtable_v2.services.bigtable.transports import ( BigtableGrpcAsyncIOTransport as TransportType, ) - from google.cloud.bigtable.data._async.mutations_batcher import ( - MutationsBatcherAsync, - _MB_SIZE, - ) - from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( - ExecuteQueryIteratorAsync, - ) - + from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE else: from grpc import insecure_channel from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType # type: ignore - from google.cloud.bigtable.data._sync_autogen.mutations_batcher import ( # noqa: F401 - MutationsBatcher, - _MB_SIZE, - ) - from google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator import ( # noqa: F401 - ExecuteQueryIterator, - ) if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples from google.cloud.bigtable.data._helpers import ShardedQuery + if CrossSync.is_async: + from google.cloud.bigtable.data._async.mutations_batcher import ( + MutationsBatcherAsync, + ) + from google.cloud.bigtable.data.execute_query._async.execute_query_iterator import ( + ExecuteQueryIteratorAsync, + ) + + __CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen.client" @@ -372,8 +367,8 @@ async def _manage_channel( await old_channel.close(grace_period) else: if grace_period: - self._is_closed.wait(grace_period) - old_channel.close() + self._is_closed.wait(grace_period) # type: ignore + old_channel.close() # type: ignore # subtract thed time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) @@ -421,7 +416,7 @@ async def _register_instance( } ) async def _remove_instance_registration( - self, instance_id: str, owner: TableAsync | ExecuteQueryIteratorAsync + self, instance_id: str, owner: TableAsync | "ExecuteQueryIteratorAsync" ) -> bool: """ Removes an instance from the client's registered instances, to prevent @@ -585,7 +580,7 @@ async def execute_query( "proto_format": {}, } - return ExecuteQueryIteratorAsync( + return CrossSync.ExecuteQueryIterator( self, instance_id, app_profile_id, @@ -1132,7 +1127,7 @@ def mutations_batcher( batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, - ) -> MutationsBatcherAsync: + ) -> "MutationsBatcherAsync": """ Returns a new mutations batcher instance. 
diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index d8ecb7d32..65070c880 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -176,9 +176,7 @@ async def add_to_flow(self, mutations: RowMutationEntry | list[RowMutationEntry] yield mutations[start_idx:end_idx] -@CrossSync.convert_class( - sync_name="MutationsBatcher", add_mapping_for_name="MutationsBatcher" -) +@CrossSync.convert_class(sync_name="MutationsBatcher") class MutationsBatcherAsync: """ Allows users to send batches using context manager API: diff --git a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py index 40dc0676a..7125f64a3 100644 --- a/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_async/execute_query_iterator.py @@ -46,8 +46,6 @@ if TYPE_CHECKING: if CrossSync.is_async: from google.cloud.bigtable.data import BigtableDataClientAsync as DataClientType - else: - from google.cloud.bigtable.data import BigtableDataClient as DataClientType __CROSS_SYNC_OUTPUT__ = ( "google.cloud.bigtable.data.execute_query._sync_autogen.execute_query_iterator" From 0d7d7ea332fe84032d19357e947bd2f4667c7941 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 11 Nov 2024 16:14:16 -0800 Subject: [PATCH 351/360] fixed lint --- docs/scripts/patch_devsite_toc.py | 4 +--- google/cloud/bigtable/data/execute_query/__init__.py | 1 + tests/unit/data/test_sync_up_to_date.py | 9 +++++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/scripts/patch_devsite_toc.py b/docs/scripts/patch_devsite_toc.py index 78cc6fd39..c794a0a90 100644 --- a/docs/scripts/patch_devsite_toc.py +++ b/docs/scripts/patch_devsite_toc.py @@ -195,9 +195,7 @@ def validate_toc(toc_file_path, expected_section_list, added_sections): # Add secrtions for the async_data_client and classic_client directories toc_path = "_build/html/docfx_yaml/toc.yml" custom_sections = [ - TocSection( - dir_name="data_client", index_file_name="usage.rst" - ), + TocSection(dir_name="data_client", index_file_name="usage.rst"), TocSection(dir_name="classic_client", index_file_name="usage.rst"), ] add_sections(toc_path, custom_sections) diff --git a/google/cloud/bigtable/data/execute_query/__init__.py b/google/cloud/bigtable/data/execute_query/__init__.py index ac49355ae..31fd5e3cc 100644 --- a/google/cloud/bigtable/data/execute_query/__init__.py +++ b/google/cloud/bigtable/data/execute_query/__init__.py @@ -29,6 +29,7 @@ Struct, ) from google.cloud.bigtable.data._cross_sync import CrossSync + CrossSync.add_mapping("ExecuteQueryIterator", ExecuteQueryIteratorAsync) CrossSync._Sync_Impl.add_mapping("ExecuteQueryIterator", ExecuteQueryIterator) diff --git a/tests/unit/data/test_sync_up_to_date.py b/tests/unit/data/test_sync_up_to_date.py index 66d9a1619..492d35ddf 100644 --- a/tests/unit/data/test_sync_up_to_date.py +++ b/tests/unit/data/test_sync_up_to_date.py @@ -29,6 +29,7 @@ sync_files = list(convert_files_in_dir(repo_root)) + def test_found_files(): """ Make sure sync_test is populated with some of the files we expect to see, @@ -42,7 +43,9 @@ def test_found_files(): assert "execute_query_iterator.py" in outputs assert "test_client.py" in outputs assert "test_system_autogen.py" in outputs, "system tests not found" - assert "client_handler_data_sync_autogen.py" in 
outputs, "test proxy handler not found" + assert ( + "client_handler_data_sync_autogen.py" in outputs + ), "test proxy handler not found" @pytest.mark.skipif( @@ -63,7 +66,9 @@ def test_sync_up_to_date(sync_file): # compare by content diff = unified_diff(found_render.splitlines(), new_render.splitlines(), lineterm="") diff_str = "\n".join(diff) - assert not diff_str, f"Found differences. Run `nox -s generate_sync` to update:\n{diff_str}" + assert ( + not diff_str + ), f"Found differences. Run `nox -s generate_sync` to update:\n{diff_str}" # compare by hash new_hash = hashlib.md5(new_render.encode()).hexdigest() found_hash = hashlib.md5(found_render.encode()).hexdigest() From 081a234072599a93d9ccb0c6d4a07900f5a685e4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 27 Nov 2024 16:22:58 -0800 Subject: [PATCH 352/360] updated sync code --- .../_sync_autogen/execute_query_iterator.py | 8 +++---- tests/system/data/test_system_autogen.py | 2 +- tests/unit/data/_sync_autogen/test_client.py | 21 +++++++++++-------- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py b/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py index 412ef2527..854148ff3 100644 --- a/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py +++ b/google/cloud/bigtable/data/execute_query/_sync_autogen/execute_query_iterator.py @@ -59,13 +59,11 @@ def __init__( client: bigtable client instance_id: id of the instance on which the query is executed request_body: dict representing the body of the ExecuteQueryRequest - attempt_timeout: the time budget for the entire operation, in seconds. - Failed requests will be retried within the budget. - Defaults to 600 seconds. - operation_timeout: the time budget for an individual network request, in seconds. + attempt_timeout: the time budget for an individual network request, in seconds. If it takes longer than this time to complete, the request will be cancelled with a DeadlineExceeded exception, and a retry will be attempted. - Defaults to the 20 seconds. If None, defaults to operation_timeout. + operation_timeout: the time budget for the entire operation, in seconds. + Failed requests will be retried within the budget req_metadata: metadata used while sending the gRPC request retryable_excs: a list of errors that will be retried if encountered. 
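# Sketch of the timeout semantics the corrected docstring describes: operation_timeout
# bounds the whole retry loop, attempt_timeout caps each individual request. This is an
# illustration of the documented behaviour, not the library's retry implementation.
import time

def run_with_timeouts(attempt_fn, operation_timeout: float, attempt_timeout: float):
    deadline = time.monotonic() + operation_timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError("operation_timeout exceeded")
        try:
            # each attempt gets at most attempt_timeout, never more than the remaining budget
            return attempt_fn(timeout=min(attempt_timeout, remaining))
        except Exception:
            # a real implementation would retry only the retryable_excs and back off
            continue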
Raises: diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index 859ed89c1..edc10679d 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -75,7 +75,7 @@ class TestSystem: @pytest.fixture(scope="session") def client(self): project = os.getenv("GOOGLE_CLOUD_PROJECT") or None - with CrossSync._Sync_Impl.DataClient(project=project, pool_size=4) as client: + with CrossSync._Sync_Impl.DataClient(project=project) as client: yield client @pytest.fixture(scope="session") diff --git a/tests/unit/data/_sync_autogen/test_client.py b/tests/unit/data/_sync_autogen/test_client.py index c0645adb4..62b6b548a 100644 --- a/tests/unit/data/_sync_autogen/test_client.py +++ b/tests/unit/data/_sync_autogen/test_client.py @@ -334,15 +334,18 @@ def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sle ] client = self._make_client(project="project-id") client.transport._grpc_channel = channel - try: - if refresh_interval is not None: - client._manage_channel( - refresh_interval, refresh_interval, grace_period=0 - ) - else: - client._manage_channel(grace_period=0) - except asyncio.CancelledError: - pass + with mock.patch.object( + client.transport, "create_channel", CrossSync._Sync_Impl.Mock + ): + try: + if refresh_interval is not None: + client._manage_channel( + refresh_interval, refresh_interval, grace_period=0 + ) + else: + client._manage_channel(grace_period=0) + except asyncio.CancelledError: + pass assert sleep.call_count == num_cycles total_sleep = sum([call[0][1] for call in sleep.call_args_list]) assert ( From d0768d3669360f64fb866d95699d693e11d33bfc Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 27 Nov 2024 16:41:20 -0800 Subject: [PATCH 353/360] fix docs issue --- docs/async_data_client/async_data_table.rst | 11 ----------- docs/data_client/async_table.rst | 7 ++++++- docs/data_client/mutations_batcher.rst | 2 +- 3 files changed, 7 insertions(+), 13 deletions(-) delete mode 100644 docs/async_data_client/async_data_table.rst diff --git a/docs/async_data_client/async_data_table.rst b/docs/async_data_client/async_data_table.rst deleted file mode 100644 index 3b7973e8e..000000000 --- a/docs/async_data_client/async_data_table.rst +++ /dev/null @@ -1,11 +0,0 @@ -Table Async -~~~~~~~~~~~ - - .. note:: - - It is generally not recommended to use the async client in an otherwise synchronous codebase. To make use of asyncio's - performance benefits, the codebase should be designed to be async from the ground up. - -.. autoclass:: google.cloud.bigtable.data._async.client.TableAsync - :members: - :show-inheritance: diff --git a/docs/data_client/async_table.rst b/docs/data_client/async_table.rst index 05ffb8fad..3b7973e8e 100644 --- a/docs/data_client/async_table.rst +++ b/docs/data_client/async_table.rst @@ -1,6 +1,11 @@ Table Async ~~~~~~~~~~~ -.. autoclass:: google.cloud.bigtable.data.TableAsync + .. note:: + + It is generally not recommended to use the async client in an otherwise synchronous codebase. To make use of asyncio's + performance benefits, the codebase should be designed to be async from the ground up. + +.. autoclass:: google.cloud.bigtable.data._async.client.TableAsync :members: :show-inheritance: diff --git a/docs/data_client/mutations_batcher.rst b/docs/data_client/mutations_batcher.rst index b21a193d1..2b7d1bfe0 100644 --- a/docs/data_client/mutations_batcher.rst +++ b/docs/data_client/mutations_batcher.rst @@ -1,6 +1,6 @@ Mutations Batcher ~~~~~~~~~~~~~~~~~ -.. 
automodule:: google.cloud.bigtable.data._sync.mutations_batcher +.. automodule:: google.cloud.bigtable.data._sync_autogen.mutations_batcher :members: :show-inheritance: From 61707ca874dbaa2785e4b321d903ebaf463ebbee Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 27 Nov 2024 16:53:16 -0800 Subject: [PATCH 354/360] remove new tests from this PR --- .cross_sync/README.md | 2 - .github/workflows/conformance.yaml | 12 +- .kokoro/conformance.sh | 10 +- noxfile.py | 2 +- test_proxy/README.md | 2 +- .../client_handler_data_sync_autogen.py | 185 ------------------ test_proxy/run_tests.sh | 16 +- test_proxy/test_proxy.py | 5 +- tests/unit/data/test_sync_up_to_date.py | 99 ---------- 9 files changed, 15 insertions(+), 318 deletions(-) delete mode 100644 test_proxy/handlers/client_handler_data_sync_autogen.py delete mode 100644 tests/unit/data/test_sync_up_to_date.py diff --git a/.cross_sync/README.md b/.cross_sync/README.md index 0d8a1cf8c..18a9aafdf 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -66,8 +66,6 @@ Generation can be initiated using `nox -s generate_sync` from the root of the project. This will find all classes with the `__CROSS_SYNC_OUTPUT__ = "path/to/output"` annotation, and generate a sync version of classes marked with `@CrossSync.convert_sync` at the output path. -There is a unit test at `tests/unit/data/test_sync_up_to_date.py` that verifies that the generated code is up to date - ## Architecture CrossSync is made up of two parts: diff --git a/.github/workflows/conformance.yaml b/.github/workflows/conformance.yaml index d4e992c8d..448e1cc3a 100644 --- a/.github/workflows/conformance.yaml +++ b/.github/workflows/conformance.yaml @@ -26,15 +26,7 @@ jobs: matrix: test-version: [ "v0.0.2" ] py-version: [ 3.8 ] - client-type: [ "async", "sync", "legacy" ] - include: - - client-type: "sync" - # sync client does not support concurrent streams - test_args: "-skip _Generic_MultiStream" - - client-type: "legacy" - # legacy client is synchtonous and does not support concurrent streams - # legacy client does not expose mutate_row. Disable those tests - test_args: "-skip _Generic_MultiStream -skip TestMutateRow_" + client-type: [ "async", "legacy" ] fail-fast: false name: "${{ matrix.client-type }} client / python ${{ matrix.py-version }} / test tag ${{ matrix.test-version }}" steps: @@ -61,6 +53,4 @@ jobs: env: CLIENT_TYPE: ${{ matrix.client-type }} PYTHONUNBUFFERED: 1 - TEST_ARGS: ${{ matrix.test_args }} - PROXY_PORT: 9999 diff --git a/.kokoro/conformance.sh b/.kokoro/conformance.sh index fd585142e..e85fc1394 100644 --- a/.kokoro/conformance.sh +++ b/.kokoro/conformance.sh @@ -19,7 +19,16 @@ set -eo pipefail ## cd to the parent directory, i.e. the root of the git repo cd $(dirname $0)/.. +PROXY_ARGS="" +TEST_ARGS="" +if [[ "${CLIENT_TYPE^^}" == "LEGACY" ]]; then + echo "Using legacy client" + # legacy client does not expose mutate_row. Disable those tests + TEST_ARGS="-skip TestMutateRow_" +fi + # Build and start the proxy in a separate process +PROXY_PORT=9999 pushd test_proxy nohup python test_proxy.py --port $PROXY_PORT --client_type=$CLIENT_TYPE & proxyPID=$! @@ -33,7 +42,6 @@ function cleanup() { trap cleanup EXIT # Run the conformance test -echo "running tests with args: $TEST_ARGS" pushd cloud-bigtable-clients-test/tests eval "go test -v -proxy_addr=:$PROXY_PORT $TEST_ARGS" RETURN_CODE=$? 
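For reference, the proxy lifecycle that conformance.sh manages above (start the proxy, run the Go conformance suite against it, always tear the proxy down) can be sketched in Python roughly as follows. The paths and commands mirror the shell script, but this helper is purely illustrative and not part of the repository.

import subprocess
import sys

def run_conformance(port: int, client_type: str, extra_test_args: list) -> int:
    # start the test proxy in the background, as the shell script does with nohup
    proxy = subprocess.Popen(
        [sys.executable, "test_proxy.py", "--port", str(port), f"--client_type={client_type}"],
        cwd="test_proxy",
    )
    try:
        # run the conformance tests against the proxy and surface their return code
        result = subprocess.run(
            ["go", "test", "-v", f"-proxy_addr=:{port}", *extra_test_args],
            cwd="cloud-bigtable-clients-test/tests",
        )
        return result.returncode
    finally:
        proxy.kill()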
diff --git a/noxfile.py b/noxfile.py index 548bfd0ec..8576fed85 100644 --- a/noxfile.py +++ b/noxfile.py @@ -298,7 +298,7 @@ def system_emulated(session): @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) -@nox.parametrize("client_type", ["async", "sync", "legacy"]) +@nox.parametrize("client_type", ["async"]) def conformance(session, client_type): # install dependencies constraints_path = str( diff --git a/test_proxy/README.md b/test_proxy/README.md index 5c87c729a..e46ed232e 100644 --- a/test_proxy/README.md +++ b/test_proxy/README.md @@ -31,7 +31,7 @@ python test_proxy.py --port 8080 ``` By default, the test_proxy targets the async client. You can change this by passing in the `--client_type` flag. -Valid options are `async`, `sync`, and `legacy`. +Valid options are `async`, and `legacy`. ``` python test_proxy.py --client_type=legacy diff --git a/test_proxy/handlers/client_handler_data_sync_autogen.py b/test_proxy/handlers/client_handler_data_sync_autogen.py deleted file mode 100644 index 52ddec6fd..000000000 --- a/test_proxy/handlers/client_handler_data_sync_autogen.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This file is automatically generated by CrossSync. Do not edit manually. - -""" -This module contains the client handler process for proxy_server.py. -""" -import os -from google.cloud.environment_vars import BIGTABLE_EMULATOR -from google.cloud.bigtable.data._cross_sync import CrossSync -from client_handler_data_async import error_safe - - -class TestProxyClientHandler: - """ - Implements the same methods as the grpc server, but handles the client - library side of the request. - - Requests received in TestProxyGrpcServer are converted to a dictionary, - and supplied to the TestProxyClientHandler methods as kwargs. 
- The client response is then returned back to the TestProxyGrpcServer - """ - - def __init__( - self, - data_target=None, - project_id=None, - instance_id=None, - app_profile_id=None, - per_operation_timeout=None, - **kwargs - ): - self.closed = False - os.environ[BIGTABLE_EMULATOR] = data_target - self.client = CrossSync._Sync_Impl.DataClient(project=project_id) - self.instance_id = instance_id - self.app_profile_id = app_profile_id - self.per_operation_timeout = per_operation_timeout - - def close(self): - self.closed = True - - @error_safe - async def ReadRows(self, request, **kwargs): - table_id = request.pop("table_name").split("/")[-1] - app_profile_id = self.app_profile_id or request.get("app_profile_id", None) - table = self.client.get_table(self.instance_id, table_id, app_profile_id) - kwargs["operation_timeout"] = ( - kwargs.get("operation_timeout", self.per_operation_timeout) or 20 - ) - result_list = table.read_rows(request, **kwargs) - serialized_response = [row._to_dict() for row in result_list] - return serialized_response - - @error_safe - async def ReadRow(self, row_key, **kwargs): - table_id = kwargs.pop("table_name").split("/")[-1] - app_profile_id = self.app_profile_id or kwargs.get("app_profile_id", None) - table = self.client.get_table(self.instance_id, table_id, app_profile_id) - kwargs["operation_timeout"] = ( - kwargs.get("operation_timeout", self.per_operation_timeout) or 20 - ) - result_row = table.read_row(row_key, **kwargs) - if result_row: - return result_row._to_dict() - else: - return "None" - - @error_safe - async def MutateRow(self, request, **kwargs): - from google.cloud.bigtable.data.mutations import Mutation - - table_id = request["table_name"].split("/")[-1] - app_profile_id = self.app_profile_id or request.get("app_profile_id", None) - table = self.client.get_table(self.instance_id, table_id, app_profile_id) - kwargs["operation_timeout"] = ( - kwargs.get("operation_timeout", self.per_operation_timeout) or 20 - ) - row_key = request["row_key"] - mutations = [Mutation._from_dict(d) for d in request["mutations"]] - table.mutate_row(row_key, mutations, **kwargs) - return "OK" - - @error_safe - async def BulkMutateRows(self, request, **kwargs): - from google.cloud.bigtable.data.mutations import RowMutationEntry - - table_id = request["table_name"].split("/")[-1] - app_profile_id = self.app_profile_id or request.get("app_profile_id", None) - table = self.client.get_table(self.instance_id, table_id, app_profile_id) - kwargs["operation_timeout"] = ( - kwargs.get("operation_timeout", self.per_operation_timeout) or 20 - ) - entry_list = [ - RowMutationEntry._from_dict(entry) for entry in request["entries"] - ] - table.bulk_mutate_rows(entry_list, **kwargs) - return "OK" - - @error_safe - async def CheckAndMutateRow(self, request, **kwargs): - from google.cloud.bigtable.data.mutations import Mutation, SetCell - - table_id = request["table_name"].split("/")[-1] - app_profile_id = self.app_profile_id or request.get("app_profile_id", None) - table = self.client.get_table(self.instance_id, table_id, app_profile_id) - kwargs["operation_timeout"] = ( - kwargs.get("operation_timeout", self.per_operation_timeout) or 20 - ) - row_key = request["row_key"] - true_mutations = [] - for mut_dict in request.get("true_mutations", []): - try: - true_mutations.append(Mutation._from_dict(mut_dict)) - except ValueError: - mutation = SetCell("", "", "", 0) - true_mutations.append(mutation) - false_mutations = [] - for mut_dict in request.get("false_mutations", []): - try: - 
false_mutations.append(Mutation._from_dict(mut_dict)) - except ValueError: - false_mutations.append(SetCell("", "", "", 0)) - predicate_filter = request.get("predicate_filter", None) - result = table.check_and_mutate_row( - row_key, - predicate_filter, - true_case_mutations=true_mutations, - false_case_mutations=false_mutations, - **kwargs - ) - return result - - @error_safe - async def ReadModifyWriteRow(self, request, **kwargs): - from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule - from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule - - table_id = request["table_name"].split("/")[-1] - app_profile_id = self.app_profile_id or request.get("app_profile_id", None) - table = self.client.get_table(self.instance_id, table_id, app_profile_id) - kwargs["operation_timeout"] = ( - kwargs.get("operation_timeout", self.per_operation_timeout) or 20 - ) - row_key = request["row_key"] - rules = [] - for rule_dict in request.get("rules", []): - qualifier = rule_dict["column_qualifier"] - if "append_value" in rule_dict: - new_rule = AppendValueRule( - rule_dict["family_name"], qualifier, rule_dict["append_value"] - ) - else: - new_rule = IncrementRule( - rule_dict["family_name"], qualifier, rule_dict["increment_amount"] - ) - rules.append(new_rule) - result = table.read_modify_write_row(row_key, rules, **kwargs) - if result: - return result._to_dict() - else: - return "None" - - @error_safe - async def SampleRowKeys(self, request, **kwargs): - table_id = request["table_name"].split("/")[-1] - app_profile_id = self.app_profile_id or request.get("app_profile_id", None) - table = self.client.get_table(self.instance_id, table_id, app_profile_id) - kwargs["operation_timeout"] = ( - kwargs.get("operation_timeout", self.per_operation_timeout) or 20 - ) - result = table.sample_row_keys(**kwargs) - return result diff --git a/test_proxy/run_tests.sh b/test_proxy/run_tests.sh index b6f1291a6..68788e3bb 100755 --- a/test_proxy/run_tests.sh +++ b/test_proxy/run_tests.sh @@ -27,7 +27,7 @@ fi SCRIPT_DIR=$(realpath $(dirname "$0")) cd $SCRIPT_DIR -export PROXY_SERVER_PORT=$(shuf -i 50000-60000 -n 1) +export PROXY_SERVER_PORT=50055 # download test suite if [ ! -d "cloud-bigtable-clients-test" ]; then @@ -43,19 +43,7 @@ function finish { } trap finish EXIT -if [[ $CLIENT_TYPE == "legacy" ]]; then - echo "Using legacy client" - # legacy client does not expose mutate_row. 
Disable those tests - TEST_ARGS="-skip TestMutateRow_" -fi - -if [[ $CLIENT_TYPE != "async" ]]; then - echo "Using legacy client" - # sync and legacy client do not support concurrent streams - TEST_ARGS="$TEST_ARGS -skip _Generic_MultiStream " -fi - # run tests pushd cloud-bigtable-clients-test/tests echo "Running with $TEST_ARGS" -go test -v -proxy_addr=:$PROXY_SERVER_PORT $TEST_ARGS +go test -v -proxy_addr=:$PROXY_SERVER_PORT diff --git a/test_proxy/test_proxy.py b/test_proxy/test_proxy.py index 793500768..9e03f1e5c 100644 --- a/test_proxy/test_proxy.py +++ b/test_proxy/test_proxy.py @@ -114,9 +114,6 @@ def format_dict(input_obj): if client_type == "legacy": import client_handler_legacy client = client_handler_legacy.LegacyTestProxyClientHandler(**json_data) - elif client_type == "sync": - import client_handler_data_sync_autogen - client = client_handler_data_sync_autogen.TestProxyClientHandler(**json_data) else: client = client_handler_data_async.TestProxyClientHandlerAsync(**json_data) client_map[client_id] = client @@ -153,7 +150,7 @@ def client_handler_process(request_q, queue_pool, client_type="async"): p = argparse.ArgumentParser() p.add_argument("--port", dest='port', default="50055") -p.add_argument("--client_type", dest='client_type', default="async", choices=["async", "sync", "legacy"]) +p.add_argument("--client_type", dest='client_type', default="async", choices=["async", "legacy"]) if __name__ == "__main__": port = p.parse_args().port diff --git a/tests/unit/data/test_sync_up_to_date.py b/tests/unit/data/test_sync_up_to_date.py deleted file mode 100644 index 492d35ddf..000000000 --- a/tests/unit/data/test_sync_up_to_date.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import sys -import hashlib -import pytest -import ast -import re -from difflib import unified_diff - -# add cross_sync to path -test_dir_name = os.path.dirname(__file__) -repo_root = os.path.join(test_dir_name, "..", "..", "..") -cross_sync_path = os.path.join(repo_root, ".cross_sync") -sys.path.append(cross_sync_path) - -from generate import convert_files_in_dir, CrossSyncOutputFile # noqa: E402 - -sync_files = list(convert_files_in_dir(repo_root)) - - -def test_found_files(): - """ - Make sure sync_test is populated with some of the files we expect to see, - to ensure that later tests are actually running. 
- """ - assert len(sync_files) > 0, "No sync files found" - assert len(sync_files) > 10, "Unexpectedly few sync files found" - # test for key files - outputs = [os.path.basename(f.output_path) for f in sync_files] - assert "client.py" in outputs - assert "execute_query_iterator.py" in outputs - assert "test_client.py" in outputs - assert "test_system_autogen.py" in outputs, "system tests not found" - assert ( - "client_handler_data_sync_autogen.py" in outputs - ), "test proxy handler not found" - - -@pytest.mark.skipif( - sys.version_info < (3, 9), reason="ast.unparse is only available in 3.9+" -) -@pytest.mark.parametrize("sync_file", sync_files, ids=lambda f: f.output_path) -def test_sync_up_to_date(sync_file): - """ - Generate a fresh copy of each cross_sync file, and compare hashes with the existing file. - - If this test fails, run `nox -s generate_sync` to update the sync files. - """ - path = sync_file.output_path - new_render = sync_file.render(with_formatter=True, save_to_disk=False) - found_render = CrossSyncOutputFile( - output_path="", ast_tree=ast.parse(open(path).read()), header=sync_file.header - ).render(with_formatter=True, save_to_disk=False) - # compare by content - diff = unified_diff(found_render.splitlines(), new_render.splitlines(), lineterm="") - diff_str = "\n".join(diff) - assert ( - not diff_str - ), f"Found differences. Run `nox -s generate_sync` to update:\n{diff_str}" - # compare by hash - new_hash = hashlib.md5(new_render.encode()).hexdigest() - found_hash = hashlib.md5(found_render.encode()).hexdigest() - assert new_hash == found_hash, f"md5 mismatch for {path}" - - -@pytest.mark.parametrize("sync_file", sync_files, ids=lambda f: f.output_path) -def test_verify_headers(sync_file): - license_regex = r""" - \#\ Copyright\ \d{4}\ Google\ LLC\n - \#\n - \#\ Licensed\ under\ the\ Apache\ License,\ Version\ 2\.0\ \(the\ \"License\"\);\n - \#\ you\ may\ not\ use\ this\ file\ except\ in\ compliance\ with\ the\ License\.\n - \#\ You\ may\ obtain\ a\ copy\ of\ the\ License\ at\ - \#\n - \#\s+http:\/\/www\.apache\.org\/licenses\/LICENSE-2\.0\n - \#\n - \#\ Unless\ required\ by\ applicable\ law\ or\ agreed\ to\ in\ writing,\ software\n - \#\ distributed\ under\ the\ License\ is\ distributed\ on\ an\ \"AS\ IS\"\ BASIS,\n - \#\ WITHOUT\ WARRANTIES\ OR\ CONDITIONS\ OF\ ANY\ KIND,\ either\ express\ or\ implied\.\n - \#\ See\ the\ License\ for\ the\ specific\ language\ governing\ permissions\ and\n - \#\ limitations\ under\ the\ License\. - """ - pattern = re.compile(license_regex, re.VERBOSE) - - with open(sync_file.output_path, "r") as f: - content = f.read() - assert pattern.search(content), "Missing license header" From b245b78cd38913e1859d8e506931af058e748893 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 27 Nov 2024 16:57:53 -0800 Subject: [PATCH 355/360] removed missing lines --- test_proxy/README.md | 2 +- test_proxy/run_tests.sh | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/test_proxy/README.md b/test_proxy/README.md index e46ed232e..266fba7cd 100644 --- a/test_proxy/README.md +++ b/test_proxy/README.md @@ -31,7 +31,7 @@ python test_proxy.py --port 8080 ``` By default, the test_proxy targets the async client. You can change this by passing in the `--client_type` flag. -Valid options are `async`, and `legacy`. +Valid options are `async` and `legacy`. 
``` python test_proxy.py --client_type=legacy diff --git a/test_proxy/run_tests.sh b/test_proxy/run_tests.sh index 68788e3bb..c2e9c6312 100755 --- a/test_proxy/run_tests.sh +++ b/test_proxy/run_tests.sh @@ -45,5 +45,4 @@ trap finish EXIT # run tests pushd cloud-bigtable-clients-test/tests -echo "Running with $TEST_ARGS" go test -v -proxy_addr=:$PROXY_SERVER_PORT From 47921f334609aec578a7866a87c6646e00ad1268 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 27 Nov 2024 16:58:22 -0800 Subject: [PATCH 356/360] Revert "removed missing lines" This reverts commit b245b78cd38913e1859d8e506931af058e748893. --- test_proxy/README.md | 2 +- test_proxy/run_tests.sh | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test_proxy/README.md b/test_proxy/README.md index 266fba7cd..e46ed232e 100644 --- a/test_proxy/README.md +++ b/test_proxy/README.md @@ -31,7 +31,7 @@ python test_proxy.py --port 8080 ``` By default, the test_proxy targets the async client. You can change this by passing in the `--client_type` flag. -Valid options are `async` and `legacy`. +Valid options are `async`, and `legacy`. ``` python test_proxy.py --client_type=legacy diff --git a/test_proxy/run_tests.sh b/test_proxy/run_tests.sh index c2e9c6312..68788e3bb 100755 --- a/test_proxy/run_tests.sh +++ b/test_proxy/run_tests.sh @@ -45,4 +45,5 @@ trap finish EXIT # run tests pushd cloud-bigtable-clients-test/tests +echo "Running with $TEST_ARGS" go test -v -proxy_addr=:$PROXY_SERVER_PORT From dfb3dc1eec4b0619bfe91b627b53d8cc6024d886 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 27 Nov 2024 16:58:39 -0800 Subject: [PATCH 357/360] Revert "remove new tests from this PR" This reverts commit 61707ca874dbaa2785e4b321d903ebaf463ebbee. --- .cross_sync/README.md | 2 + .github/workflows/conformance.yaml | 12 +- .kokoro/conformance.sh | 10 +- noxfile.py | 2 +- test_proxy/README.md | 2 +- .../client_handler_data_sync_autogen.py | 185 ++++++++++++++++++ test_proxy/run_tests.sh | 16 +- test_proxy/test_proxy.py | 5 +- tests/unit/data/test_sync_up_to_date.py | 99 ++++++++++ 9 files changed, 318 insertions(+), 15 deletions(-) create mode 100644 test_proxy/handlers/client_handler_data_sync_autogen.py create mode 100644 tests/unit/data/test_sync_up_to_date.py diff --git a/.cross_sync/README.md b/.cross_sync/README.md index 18a9aafdf..0d8a1cf8c 100644 --- a/.cross_sync/README.md +++ b/.cross_sync/README.md @@ -66,6 +66,8 @@ Generation can be initiated using `nox -s generate_sync` from the root of the project. This will find all classes with the `__CROSS_SYNC_OUTPUT__ = "path/to/output"` annotation, and generate a sync version of classes marked with `@CrossSync.convert_sync` at the output path. +There is a unit test at `tests/unit/data/test_sync_up_to_date.py` that verifies that the generated code is up to date + ## Architecture CrossSync is made up of two parts: diff --git a/.github/workflows/conformance.yaml b/.github/workflows/conformance.yaml index 448e1cc3a..d4e992c8d 100644 --- a/.github/workflows/conformance.yaml +++ b/.github/workflows/conformance.yaml @@ -26,7 +26,15 @@ jobs: matrix: test-version: [ "v0.0.2" ] py-version: [ 3.8 ] - client-type: [ "async", "legacy" ] + client-type: [ "async", "sync", "legacy" ] + include: + - client-type: "sync" + # sync client does not support concurrent streams + test_args: "-skip _Generic_MultiStream" + - client-type: "legacy" + # legacy client is synchtonous and does not support concurrent streams + # legacy client does not expose mutate_row. 
Disable those tests + test_args: "-skip _Generic_MultiStream -skip TestMutateRow_" fail-fast: false name: "${{ matrix.client-type }} client / python ${{ matrix.py-version }} / test tag ${{ matrix.test-version }}" steps: @@ -53,4 +61,6 @@ jobs: env: CLIENT_TYPE: ${{ matrix.client-type }} PYTHONUNBUFFERED: 1 + TEST_ARGS: ${{ matrix.test_args }} + PROXY_PORT: 9999 diff --git a/.kokoro/conformance.sh b/.kokoro/conformance.sh index e85fc1394..fd585142e 100644 --- a/.kokoro/conformance.sh +++ b/.kokoro/conformance.sh @@ -19,16 +19,7 @@ set -eo pipefail ## cd to the parent directory, i.e. the root of the git repo cd $(dirname $0)/.. -PROXY_ARGS="" -TEST_ARGS="" -if [[ "${CLIENT_TYPE^^}" == "LEGACY" ]]; then - echo "Using legacy client" - # legacy client does not expose mutate_row. Disable those tests - TEST_ARGS="-skip TestMutateRow_" -fi - # Build and start the proxy in a separate process -PROXY_PORT=9999 pushd test_proxy nohup python test_proxy.py --port $PROXY_PORT --client_type=$CLIENT_TYPE & proxyPID=$! @@ -42,6 +33,7 @@ function cleanup() { trap cleanup EXIT # Run the conformance test +echo "running tests with args: $TEST_ARGS" pushd cloud-bigtable-clients-test/tests eval "go test -v -proxy_addr=:$PROXY_PORT $TEST_ARGS" RETURN_CODE=$? diff --git a/noxfile.py b/noxfile.py index 8576fed85..548bfd0ec 100644 --- a/noxfile.py +++ b/noxfile.py @@ -298,7 +298,7 @@ def system_emulated(session): @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) -@nox.parametrize("client_type", ["async"]) +@nox.parametrize("client_type", ["async", "sync", "legacy"]) def conformance(session, client_type): # install dependencies constraints_path = str( diff --git a/test_proxy/README.md b/test_proxy/README.md index e46ed232e..5c87c729a 100644 --- a/test_proxy/README.md +++ b/test_proxy/README.md @@ -31,7 +31,7 @@ python test_proxy.py --port 8080 ``` By default, the test_proxy targets the async client. You can change this by passing in the `--client_type` flag. -Valid options are `async`, and `legacy`. +Valid options are `async`, `sync`, and `legacy`. ``` python test_proxy.py --client_type=legacy diff --git a/test_proxy/handlers/client_handler_data_sync_autogen.py b/test_proxy/handlers/client_handler_data_sync_autogen.py new file mode 100644 index 000000000..52ddec6fd --- /dev/null +++ b/test_proxy/handlers/client_handler_data_sync_autogen.py @@ -0,0 +1,185 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is automatically generated by CrossSync. Do not edit manually. + +""" +This module contains the client handler process for proxy_server.py. +""" +import os +from google.cloud.environment_vars import BIGTABLE_EMULATOR +from google.cloud.bigtable.data._cross_sync import CrossSync +from client_handler_data_async import error_safe + + +class TestProxyClientHandler: + """ + Implements the same methods as the grpc server, but handles the client + library side of the request. 
+ + Requests received in TestProxyGrpcServer are converted to a dictionary, + and supplied to the TestProxyClientHandler methods as kwargs. + The client response is then returned back to the TestProxyGrpcServer + """ + + def __init__( + self, + data_target=None, + project_id=None, + instance_id=None, + app_profile_id=None, + per_operation_timeout=None, + **kwargs + ): + self.closed = False + os.environ[BIGTABLE_EMULATOR] = data_target + self.client = CrossSync._Sync_Impl.DataClient(project=project_id) + self.instance_id = instance_id + self.app_profile_id = app_profile_id + self.per_operation_timeout = per_operation_timeout + + def close(self): + self.closed = True + + @error_safe + async def ReadRows(self, request, **kwargs): + table_id = request.pop("table_name").split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + result_list = table.read_rows(request, **kwargs) + serialized_response = [row._to_dict() for row in result_list] + return serialized_response + + @error_safe + async def ReadRow(self, row_key, **kwargs): + table_id = kwargs.pop("table_name").split("/")[-1] + app_profile_id = self.app_profile_id or kwargs.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + result_row = table.read_row(row_key, **kwargs) + if result_row: + return result_row._to_dict() + else: + return "None" + + @error_safe + async def MutateRow(self, request, **kwargs): + from google.cloud.bigtable.data.mutations import Mutation + + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + row_key = request["row_key"] + mutations = [Mutation._from_dict(d) for d in request["mutations"]] + table.mutate_row(row_key, mutations, **kwargs) + return "OK" + + @error_safe + async def BulkMutateRows(self, request, **kwargs): + from google.cloud.bigtable.data.mutations import RowMutationEntry + + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + entry_list = [ + RowMutationEntry._from_dict(entry) for entry in request["entries"] + ] + table.bulk_mutate_rows(entry_list, **kwargs) + return "OK" + + @error_safe + async def CheckAndMutateRow(self, request, **kwargs): + from google.cloud.bigtable.data.mutations import Mutation, SetCell + + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + row_key = request["row_key"] + true_mutations = [] + for mut_dict in request.get("true_mutations", []): + try: + true_mutations.append(Mutation._from_dict(mut_dict)) + except ValueError: + mutation = 
SetCell("", "", "", 0) + true_mutations.append(mutation) + false_mutations = [] + for mut_dict in request.get("false_mutations", []): + try: + false_mutations.append(Mutation._from_dict(mut_dict)) + except ValueError: + false_mutations.append(SetCell("", "", "", 0)) + predicate_filter = request.get("predicate_filter", None) + result = table.check_and_mutate_row( + row_key, + predicate_filter, + true_case_mutations=true_mutations, + false_case_mutations=false_mutations, + **kwargs + ) + return result + + @error_safe + async def ReadModifyWriteRow(self, request, **kwargs): + from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule + from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule + + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + row_key = request["row_key"] + rules = [] + for rule_dict in request.get("rules", []): + qualifier = rule_dict["column_qualifier"] + if "append_value" in rule_dict: + new_rule = AppendValueRule( + rule_dict["family_name"], qualifier, rule_dict["append_value"] + ) + else: + new_rule = IncrementRule( + rule_dict["family_name"], qualifier, rule_dict["increment_amount"] + ) + rules.append(new_rule) + result = table.read_modify_write_row(row_key, rules, **kwargs) + if result: + return result._to_dict() + else: + return "None" + + @error_safe + async def SampleRowKeys(self, request, **kwargs): + table_id = request["table_name"].split("/")[-1] + app_profile_id = self.app_profile_id or request.get("app_profile_id", None) + table = self.client.get_table(self.instance_id, table_id, app_profile_id) + kwargs["operation_timeout"] = ( + kwargs.get("operation_timeout", self.per_operation_timeout) or 20 + ) + result = table.sample_row_keys(**kwargs) + return result diff --git a/test_proxy/run_tests.sh b/test_proxy/run_tests.sh index 68788e3bb..b6f1291a6 100755 --- a/test_proxy/run_tests.sh +++ b/test_proxy/run_tests.sh @@ -27,7 +27,7 @@ fi SCRIPT_DIR=$(realpath $(dirname "$0")) cd $SCRIPT_DIR -export PROXY_SERVER_PORT=50055 +export PROXY_SERVER_PORT=$(shuf -i 50000-60000 -n 1) # download test suite if [ ! -d "cloud-bigtable-clients-test" ]; then @@ -43,7 +43,19 @@ function finish { } trap finish EXIT +if [[ $CLIENT_TYPE == "legacy" ]]; then + echo "Using legacy client" + # legacy client does not expose mutate_row. 
Disable those tests + TEST_ARGS="-skip TestMutateRow_" +fi + +if [[ $CLIENT_TYPE != "async" ]]; then + echo "Using legacy client" + # sync and legacy client do not support concurrent streams + TEST_ARGS="$TEST_ARGS -skip _Generic_MultiStream " +fi + # run tests pushd cloud-bigtable-clients-test/tests echo "Running with $TEST_ARGS" -go test -v -proxy_addr=:$PROXY_SERVER_PORT +go test -v -proxy_addr=:$PROXY_SERVER_PORT $TEST_ARGS diff --git a/test_proxy/test_proxy.py b/test_proxy/test_proxy.py index 9e03f1e5c..793500768 100644 --- a/test_proxy/test_proxy.py +++ b/test_proxy/test_proxy.py @@ -114,6 +114,9 @@ def format_dict(input_obj): if client_type == "legacy": import client_handler_legacy client = client_handler_legacy.LegacyTestProxyClientHandler(**json_data) + elif client_type == "sync": + import client_handler_data_sync_autogen + client = client_handler_data_sync_autogen.TestProxyClientHandler(**json_data) else: client = client_handler_data_async.TestProxyClientHandlerAsync(**json_data) client_map[client_id] = client @@ -150,7 +153,7 @@ def client_handler_process(request_q, queue_pool, client_type="async"): p = argparse.ArgumentParser() p.add_argument("--port", dest='port', default="50055") -p.add_argument("--client_type", dest='client_type', default="async", choices=["async", "legacy"]) +p.add_argument("--client_type", dest='client_type', default="async", choices=["async", "sync", "legacy"]) if __name__ == "__main__": port = p.parse_args().port diff --git a/tests/unit/data/test_sync_up_to_date.py b/tests/unit/data/test_sync_up_to_date.py new file mode 100644 index 000000000..492d35ddf --- /dev/null +++ b/tests/unit/data/test_sync_up_to_date.py @@ -0,0 +1,99 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import hashlib +import pytest +import ast +import re +from difflib import unified_diff + +# add cross_sync to path +test_dir_name = os.path.dirname(__file__) +repo_root = os.path.join(test_dir_name, "..", "..", "..") +cross_sync_path = os.path.join(repo_root, ".cross_sync") +sys.path.append(cross_sync_path) + +from generate import convert_files_in_dir, CrossSyncOutputFile # noqa: E402 + +sync_files = list(convert_files_in_dir(repo_root)) + + +def test_found_files(): + """ + Make sure sync_test is populated with some of the files we expect to see, + to ensure that later tests are actually running. 
+ """ + assert len(sync_files) > 0, "No sync files found" + assert len(sync_files) > 10, "Unexpectedly few sync files found" + # test for key files + outputs = [os.path.basename(f.output_path) for f in sync_files] + assert "client.py" in outputs + assert "execute_query_iterator.py" in outputs + assert "test_client.py" in outputs + assert "test_system_autogen.py" in outputs, "system tests not found" + assert ( + "client_handler_data_sync_autogen.py" in outputs + ), "test proxy handler not found" + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="ast.unparse is only available in 3.9+" +) +@pytest.mark.parametrize("sync_file", sync_files, ids=lambda f: f.output_path) +def test_sync_up_to_date(sync_file): + """ + Generate a fresh copy of each cross_sync file, and compare hashes with the existing file. + + If this test fails, run `nox -s generate_sync` to update the sync files. + """ + path = sync_file.output_path + new_render = sync_file.render(with_formatter=True, save_to_disk=False) + found_render = CrossSyncOutputFile( + output_path="", ast_tree=ast.parse(open(path).read()), header=sync_file.header + ).render(with_formatter=True, save_to_disk=False) + # compare by content + diff = unified_diff(found_render.splitlines(), new_render.splitlines(), lineterm="") + diff_str = "\n".join(diff) + assert ( + not diff_str + ), f"Found differences. Run `nox -s generate_sync` to update:\n{diff_str}" + # compare by hash + new_hash = hashlib.md5(new_render.encode()).hexdigest() + found_hash = hashlib.md5(found_render.encode()).hexdigest() + assert new_hash == found_hash, f"md5 mismatch for {path}" + + +@pytest.mark.parametrize("sync_file", sync_files, ids=lambda f: f.output_path) +def test_verify_headers(sync_file): + license_regex = r""" + \#\ Copyright\ \d{4}\ Google\ LLC\n + \#\n + \#\ Licensed\ under\ the\ Apache\ License,\ Version\ 2\.0\ \(the\ \"License\"\);\n + \#\ you\ may\ not\ use\ this\ file\ except\ in\ compliance\ with\ the\ License\.\n + \#\ You\ may\ obtain\ a\ copy\ of\ the\ License\ at\ + \#\n + \#\s+http:\/\/www\.apache\.org\/licenses\/LICENSE-2\.0\n + \#\n + \#\ Unless\ required\ by\ applicable\ law\ or\ agreed\ to\ in\ writing,\ software\n + \#\ distributed\ under\ the\ License\ is\ distributed\ on\ an\ \"AS\ IS\"\ BASIS,\n + \#\ WITHOUT\ WARRANTIES\ OR\ CONDITIONS\ OF\ ANY\ KIND,\ either\ express\ or\ implied\.\n + \#\ See\ the\ License\ for\ the\ specific\ language\ governing\ permissions\ and\n + \#\ limitations\ under\ the\ License\. 
+ """ + pattern = re.compile(license_regex, re.VERBOSE) + + with open(sync_file.output_path, "r") as f: + content = f.read() + assert pattern.search(content), "Missing license header" From 6a456c1ef4ff5c504ca8558a657dd854e4d412c6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Dec 2024 14:06:54 -0800 Subject: [PATCH 358/360] removed duplicate files --- docs/data_client/async_client.rst | 12 ---- .../async_execute_query_iterator.rst | 6 -- docs/data_client/async_mutations_batcher.rst | 6 -- docs/data_client/async_table.rst | 11 ---- docs/data_client/client.rst | 6 -- docs/data_client/exceptions.rst | 6 -- docs/data_client/execute_query_iterator.rst | 6 -- docs/data_client/execute_query_metadata.rst | 6 -- docs/data_client/execute_query_values.rst | 6 -- docs/data_client/mutations.rst | 6 -- docs/data_client/mutations_batcher.rst | 6 -- docs/data_client/read_modify_write_rules.rst | 6 -- docs/data_client/read_rows_query.rst | 6 -- docs/data_client/row.rst | 6 -- docs/data_client/row_filters.rst | 62 ------------------- docs/data_client/table.rst | 6 -- docs/data_client/usage.rst | 39 ------------ 17 files changed, 202 deletions(-) delete mode 100644 docs/data_client/async_client.rst delete mode 100644 docs/data_client/async_execute_query_iterator.rst delete mode 100644 docs/data_client/async_mutations_batcher.rst delete mode 100644 docs/data_client/async_table.rst delete mode 100644 docs/data_client/client.rst delete mode 100644 docs/data_client/exceptions.rst delete mode 100644 docs/data_client/execute_query_iterator.rst delete mode 100644 docs/data_client/execute_query_metadata.rst delete mode 100644 docs/data_client/execute_query_values.rst delete mode 100644 docs/data_client/mutations.rst delete mode 100644 docs/data_client/mutations_batcher.rst delete mode 100644 docs/data_client/read_modify_write_rules.rst delete mode 100644 docs/data_client/read_rows_query.rst delete mode 100644 docs/data_client/row.rst delete mode 100644 docs/data_client/row_filters.rst delete mode 100644 docs/data_client/table.rst delete mode 100644 docs/data_client/usage.rst diff --git a/docs/data_client/async_client.rst b/docs/data_client/async_client.rst deleted file mode 100644 index 2ddcc090c..000000000 --- a/docs/data_client/async_client.rst +++ /dev/null @@ -1,12 +0,0 @@ -Bigtable Data Client Async -~~~~~~~~~~~~~~~~~~~~~~~~~~ - - .. note:: - - It is generally not recommended to use the async client in an otherwise synchronous codebase. To make use of asyncio's - performance benefits, the codebase should be designed to be async from the ground up. - - -.. autoclass:: google.cloud.bigtable.data.BigtableDataClientAsync - :members: - :show-inheritance: diff --git a/docs/data_client/async_execute_query_iterator.rst b/docs/data_client/async_execute_query_iterator.rst deleted file mode 100644 index b911fab7f..000000000 --- a/docs/data_client/async_execute_query_iterator.rst +++ /dev/null @@ -1,6 +0,0 @@ -Execute Query Iterator Async -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: google.cloud.bigtable.data.execute_query.ExecuteQueryIteratorAsync - :members: - :show-inheritance: diff --git a/docs/data_client/async_mutations_batcher.rst b/docs/data_client/async_mutations_batcher.rst deleted file mode 100644 index 3e81f885a..000000000 --- a/docs/data_client/async_mutations_batcher.rst +++ /dev/null @@ -1,6 +0,0 @@ -Mutations Batcher Async -~~~~~~~~~~~~~~~~~~~~~~~ - -.. 
automodule:: google.cloud.bigtable.data._async.mutations_batcher - :members: - :show-inheritance: diff --git a/docs/data_client/async_table.rst b/docs/data_client/async_table.rst deleted file mode 100644 index 3b7973e8e..000000000 --- a/docs/data_client/async_table.rst +++ /dev/null @@ -1,11 +0,0 @@ -Table Async -~~~~~~~~~~~ - - .. note:: - - It is generally not recommended to use the async client in an otherwise synchronous codebase. To make use of asyncio's - performance benefits, the codebase should be designed to be async from the ground up. - -.. autoclass:: google.cloud.bigtable.data._async.client.TableAsync - :members: - :show-inheritance: diff --git a/docs/data_client/client.rst b/docs/data_client/client.rst deleted file mode 100644 index cf7c00dad..000000000 --- a/docs/data_client/client.rst +++ /dev/null @@ -1,6 +0,0 @@ -Bigtable Data Client -~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: google.cloud.bigtable.data.BigtableDataClient - :members: - :show-inheritance: diff --git a/docs/data_client/exceptions.rst b/docs/data_client/exceptions.rst deleted file mode 100644 index 6180ef222..000000000 --- a/docs/data_client/exceptions.rst +++ /dev/null @@ -1,6 +0,0 @@ -Custom Exceptions -~~~~~~~~~~~~~~~~~ - -.. automodule:: google.cloud.bigtable.data.exceptions - :members: - :show-inheritance: diff --git a/docs/data_client/execute_query_iterator.rst b/docs/data_client/execute_query_iterator.rst deleted file mode 100644 index 6eb9f84db..000000000 --- a/docs/data_client/execute_query_iterator.rst +++ /dev/null @@ -1,6 +0,0 @@ -Execute Query Iterator -~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: google.cloud.bigtable.data.execute_query.ExecuteQueryIterator - :members: - :show-inheritance: diff --git a/docs/data_client/execute_query_metadata.rst b/docs/data_client/execute_query_metadata.rst deleted file mode 100644 index 69add630d..000000000 --- a/docs/data_client/execute_query_metadata.rst +++ /dev/null @@ -1,6 +0,0 @@ -Execute Query Metadata -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. automodule:: google.cloud.bigtable.data.execute_query.metadata - :members: - :show-inheritance: diff --git a/docs/data_client/execute_query_values.rst b/docs/data_client/execute_query_values.rst deleted file mode 100644 index 6c4fb71c1..000000000 --- a/docs/data_client/execute_query_values.rst +++ /dev/null @@ -1,6 +0,0 @@ -Execute Query Values -~~~~~~~~~~~~~~~~~~~~ - -.. automodule:: google.cloud.bigtable.data.execute_query.values - :members: - :show-inheritance: diff --git a/docs/data_client/mutations.rst b/docs/data_client/mutations.rst deleted file mode 100644 index 9d7a9eab2..000000000 --- a/docs/data_client/mutations.rst +++ /dev/null @@ -1,6 +0,0 @@ -Mutations -~~~~~~~~~ - -.. automodule:: google.cloud.bigtable.data.mutations - :members: - :show-inheritance: diff --git a/docs/data_client/mutations_batcher.rst b/docs/data_client/mutations_batcher.rst deleted file mode 100644 index 2b7d1bfe0..000000000 --- a/docs/data_client/mutations_batcher.rst +++ /dev/null @@ -1,6 +0,0 @@ -Mutations Batcher -~~~~~~~~~~~~~~~~~ - -.. automodule:: google.cloud.bigtable.data._sync_autogen.mutations_batcher - :members: - :show-inheritance: diff --git a/docs/data_client/read_modify_write_rules.rst b/docs/data_client/read_modify_write_rules.rst deleted file mode 100644 index 2f28ddf3f..000000000 --- a/docs/data_client/read_modify_write_rules.rst +++ /dev/null @@ -1,6 +0,0 @@ -Read Modify Write Rules -~~~~~~~~~~~~~~~~~~~~~~~ - -.. 
automodule:: google.cloud.bigtable.data.read_modify_write_rules - :members: - :show-inheritance: diff --git a/docs/data_client/read_rows_query.rst b/docs/data_client/read_rows_query.rst deleted file mode 100644 index 4e3e796d9..000000000 --- a/docs/data_client/read_rows_query.rst +++ /dev/null @@ -1,6 +0,0 @@ -Read Rows Query -~~~~~~~~~~~~~~~ - -.. automodule:: google.cloud.bigtable.data.read_rows_query - :members: - :show-inheritance: diff --git a/docs/data_client/row.rst b/docs/data_client/row.rst deleted file mode 100644 index 63bc71143..000000000 --- a/docs/data_client/row.rst +++ /dev/null @@ -1,6 +0,0 @@ -Rows and Cells -~~~~~~~~~~~~~~ - -.. automodule:: google.cloud.bigtable.data.row - :members: - :show-inheritance: diff --git a/docs/data_client/row_filters.rst b/docs/data_client/row_filters.rst deleted file mode 100644 index 22bda8a26..000000000 --- a/docs/data_client/row_filters.rst +++ /dev/null @@ -1,62 +0,0 @@ -Bigtable Row Filters -==================== - -It is possible to use a -:class:`RowFilter ` -when constructing a :class:`ReadRowsQuery ` - -The following basic filters -are provided: - -* :class:`SinkFilter <.data.row_filters.SinkFilter>` -* :class:`PassAllFilter <.data.row_filters.PassAllFilter>` -* :class:`BlockAllFilter <.data.row_filters.BlockAllFilter>` -* :class:`RowKeyRegexFilter <.data.row_filters.RowKeyRegexFilter>` -* :class:`RowSampleFilter <.data.row_filters.RowSampleFilter>` -* :class:`FamilyNameRegexFilter <.data.row_filters.FamilyNameRegexFilter>` -* :class:`ColumnQualifierRegexFilter <.data.row_filters.ColumnQualifierRegexFilter>` -* :class:`TimestampRangeFilter <.data.row_filters.TimestampRangeFilter>` -* :class:`ColumnRangeFilter <.data.row_filters.ColumnRangeFilter>` -* :class:`ValueRegexFilter <.data.row_filters.ValueRegexFilter>` -* :class:`ValueRangeFilter <.data.row_filters.ValueRangeFilter>` -* :class:`CellsRowOffsetFilter <.data.row_filters.CellsRowOffsetFilter>` -* :class:`CellsRowLimitFilter <.data.row_filters.CellsRowLimitFilter>` -* :class:`CellsColumnLimitFilter <.data.row_filters.CellsColumnLimitFilter>` -* :class:`StripValueTransformerFilter <.data.row_filters.StripValueTransformerFilter>` -* :class:`ApplyLabelFilter <.data.row_filters.ApplyLabelFilter>` - -In addition, these filters can be combined into composite filters with - -* :class:`RowFilterChain <.data.row_filters.RowFilterChain>` -* :class:`RowFilterUnion <.data.row_filters.RowFilterUnion>` -* :class:`ConditionalRowFilter <.data.row_filters.ConditionalRowFilter>` - -These rules can be nested arbitrarily, with a basic filter at the lowest -level. For example: - -.. code:: python - - # Filter in a specified column (matching any column family). - col1_filter = ColumnQualifierRegexFilter(b'columnbia') - - # Create a filter to label results. - label1 = u'label-red' - label1_filter = ApplyLabelFilter(label1) - - # Combine the filters to label all the cells in columnbia. - chain1 = RowFilterChain(filters=[col1_filter, label1_filter]) - - # Create a similar filter to label cells blue. - col2_filter = ColumnQualifierRegexFilter(b'columnseeya') - label2 = u'label-blue' - label2_filter = ApplyLabelFilter(label2) - chain2 = RowFilterChain(filters=[col2_filter, label2_filter]) - - # Bring our two labeled columns together. - row_filter = RowFilterUnion(filters=[chain1, chain2]) - ----- - -.. 
automodule:: google.cloud.bigtable.data.row_filters - :members: - :show-inheritance: diff --git a/docs/data_client/table.rst b/docs/data_client/table.rst deleted file mode 100644 index 95c91eb27..000000000 --- a/docs/data_client/table.rst +++ /dev/null @@ -1,6 +0,0 @@ -Table -~~~~~ - -.. autoclass:: google.cloud.bigtable.data.Table - :members: - :show-inheritance: diff --git a/docs/data_client/usage.rst b/docs/data_client/usage.rst deleted file mode 100644 index 8edc424ec..000000000 --- a/docs/data_client/usage.rst +++ /dev/null @@ -1,39 +0,0 @@ -Data Client -=========== - -Sync Surface ------------- - -.. toctree:: - :maxdepth: 3 - - client - table - mutations_batcher - execute_query_iterator - -Async Surface -------------- - -.. toctree:: - :maxdepth: 3 - - async_client - async_table - async_mutations_batcher - async_execute_query_iterator - -Common Classes --------------- - -.. toctree:: - :maxdepth: 3 - - read_rows_query - row - row_filters - mutations - read_modify_write_rules - exceptions - execute_query_values - execute_query_metadata From dbe86f13d094abdebc4ad1368f99f61b2c434333 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Dec 2024 14:09:01 -0800 Subject: [PATCH 359/360] regenerated files --- test_proxy/handlers/client_handler_data_sync_autogen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_proxy/handlers/client_handler_data_sync_autogen.py b/test_proxy/handlers/client_handler_data_sync_autogen.py index 52ddec6fd..eabae0ffa 100644 --- a/test_proxy/handlers/client_handler_data_sync_autogen.py +++ b/test_proxy/handlers/client_handler_data_sync_autogen.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 710190bf275b8951ff5d679f22964eea0e1bed65 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Dec 2024 10:10:40 -0800 Subject: [PATCH 360/360] fix typo Co-authored-by: Mattie Fu --- .github/workflows/conformance.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/conformance.yaml b/.github/workflows/conformance.yaml index d4e992c8d..8445240c3 100644 --- a/.github/workflows/conformance.yaml +++ b/.github/workflows/conformance.yaml @@ -32,7 +32,7 @@ jobs: # sync client does not support concurrent streams test_args: "-skip _Generic_MultiStream" - client-type: "legacy" - # legacy client is synchtonous and does not support concurrent streams + # legacy client is synchronous and does not support concurrent streams # legacy client does not expose mutate_row. Disable those tests test_args: "-skip _Generic_MultiStream -skip TestMutateRow_" fail-fast: false